def evaluate(model, step, vocoder=None):
    """Run validation for the pinyin/hanzi FastSpeech2 variant.

    Computes duration / mel / mel-postnet losses over the whole validation
    set and, when `vocoder` is given, synthesizes audio and plots
    spectrograms for every validation sample into `hp.eval_path`.

    Args:
        model: FastSpeech2-style model; called with keyword arguments
            (src_seq, src_len, hz_seq, mel_len, d_target, max_src_len,
            max_mel_len) and expected to return 6 values — see the call below.
        step: Global training step, used only for logging.
        vocoder: Optional vocoder network; when None, no audio is generated.

    Returns:
        (d_l, mel_l, mel_p_l): dataset-averaged duration, mel and
        mel-postnet losses as floats.
    """
    print('evaluating..')
    # Build the validation dataset; the hanzi stream is optional and
    # controlled by hp.with_hanzi.
    if hp.with_hanzi:
        dataset = Dataset(filename_py="val_pinyin.txt", vocab_file_py='vocab_pinyin.txt',
                          filename_hz="val_hanzi.txt", vocab_file_hz='vocab_hanzi.txt')
        py_vocab_size = len(dataset.py_vocab)
        hz_vocab_size = len(dataset.hz_vocab)
    else:
        dataset = Dataset(filename_py="val_pinyin.txt", vocab_file_py='vocab_pinyin.txt',
                          filename_hz=None, vocab_file_hz=None)
        py_vocab_size = len(dataset.py_vocab)
        hz_vocab_size = None
    # NOTE(review): batch_size**2 suggests the collate_fn re-splits each
    # loader batch into `hp.batch_size`-sized sub-batches (the inner
    # `batchs` loop below) — TODO confirm against Dataset.collate_fn.
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=0,
                        )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Per-batch loss accumulators; f_l / e_l are kept for interface parity
    # but never filled (the F0/energy losses are commented out below).
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    # NOTE(review): tqdm.tqdm_notebook is deprecated in modern tqdm;
    # only works as intended inside a notebook environment.
    bar = tqdm.tqdm_notebook(total=len(dataset)//hp.batch_size)
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            bar.update(1)
            # Move the batch tensors onto the target device.
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            if hp.with_hanzi:
                hz_text = torch.from_numpy(
                    data_of_batch["hz_text"]).long().to(device)
            else:
                hz_text = None
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            # NOTE(review): log_D is cast to int here while the sibling
            # evaluate() uses float for log durations — verify the loss
            # actually expects integer log-durations.
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)
            with torch.no_grad():
                # Teacher-forced forward pass (ground-truth durations D).
                mel_output, mel_postnet_output, log_duration_output, src_mask, mel_mask, out_mel_len = model(
                    src_seq=text, src_len=src_len, hz_seq=hz_text, mel_len=mel_len,
                    d_target=D, max_src_len=max_src_len, max_mel_len=max_mel_len)
                # Targets are mean-normalized with hp.mel_mean; the masks are
                # inverted because the model emits padding masks.
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    log_duration_output, log_D, mel_output,
                    mel_postnet_output, mel_target-hp.mel_mean, ~src_mask, ~mel_mask)
                d_l.append(d_loss.item())
                # f_l.append(f_loss.item())
                # e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())
                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]
                        # (1, n_mels, T) views for the vocoders; (n_mels, T)
                        # CPU copies for plotting.
                        mel_target_torch = mel_target[k:k+1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[k, :gt_length].cpu(
                        ).transpose(0, 1).detach()
                        mel_postnet_torch = mel_postnet_output[k:k + 1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[k, :out_length].cpu(
                        ).transpose(0, 1).detach()
                        # hp.mel_mean is added back because the model was
                        # trained on mean-subtracted targets (see Loss above).
                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(mel_target_torch, vocoder, os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.melgan_infer(mel_postnet_torch+hp.mel_mean, vocoder, os.path.join(
                                hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(mel_target_torch, vocoder, os.path.join(
                                hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder)))
                            utils.waveglow_infer(mel_postnet_torch+hp.mel_mean, vocoder, os.path.join(
                                hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder)))
                        utils.plot_data([mel_postnet.numpy()+hp.mel_mean, mel_target_.numpy()],
                                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                                        filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
                        idx += 1
            current_step += 1

    # Dataset-level averages (names are rebound from list to float here).
    d_l = sum(d_l) / len(d_l)
    # f_l = sum(f_l) / len(f_l)
    # e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    #str3 = "F0 Loss: {}".format(f_l)
    # str4 = "Energy Loss: {}".format(e_l)
    str4 = "Mel Loss: {}".format(mel_l)
    str5 = "Mel Postnet Loss: {}".format(mel_p_l)
    str6 = "total Loss: {}".format(mel_p_l+mel_l+d_l)

    print("\n" + str1)
    print(str2)
    # print(str3)
    print(str4)
    print(str5)
    print(str6)

    # Append the summary to the persistent evaluation log.
    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        # f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, mel_l, mel_p_l
def evaluate(model, step):
    """Evaluate the model on the validation set and return averaged losses.

    Supports both the WORLD-vocoder variant (ap/sp spectra) and the mel
    variant of the model; in both cases it runs a teacher-forced forward
    pass under `torch.no_grad()`, averages the component losses over all
    validation batches, prints a summary, and appends it to
    `hp.log_path/eval.txt`.

    Fixes relative to the original:
      * Loss accumulators were initialized as lists but then *overwritten*
        with floats inside the loop, so the final `sum(x)/len(x)` raised
        TypeError.  They are now appended per batch and averaged once.
      * The vocoding/plotting section referenced names that are never
        defined in this function (`vocoder`, `out_mel_len`, `mel_postnet`,
        `mel_target_`) and would raise NameError on the first sample; it
        has been removed.
      * In the WORLD branch the summary used `mel_l`/`mel_p_l`, which were
        undefined there; the sp / sp-postnet losses are reported (and
        returned) in those slots instead, keeping the 5-value return
        contract identical for both modes.

    Args:
        model: FastSpeech2-style model (WORLD or mel variant).
        step: Global training step, used only for logging.

    Returns:
        (d_l, f_l, e_l, mel_l, mel_p_l): averaged duration, F0, energy and
        spectral / spectral-postnet losses as floats.
    """
    torch.manual_seed(0)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size*4,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=0,
                        )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Per-batch loss accumulators (averaged at the end).
    d_l = []
    f_l = []
    e_l = []
    if hp.vocoder == 'WORLD':
        ap_l = []
        sp_l = []
        sp_p_l = []
    else:
        mel_l = []
        mel_p_l = []

    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            condition = torch.from_numpy(data_of_batch["condition"]).long().to(device)
            mel_refer = torch.from_numpy(data_of_batch["mel_refer"]).float().to(device)
            if hp.vocoder == 'WORLD':
                ap_target = torch.from_numpy(data_of_batch["ap_target"]).float().to(device)
                sp_target = torch.from_numpy(data_of_batch["sp_target"]).float().to(device)
            else:
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward (teacher-forced with ground-truth D / f0 / energy).
                if hp.vocoder == 'WORLD':
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0, energy_output, energy,
                        ap_output=ap_output, sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output, ap_target=ap_target,
                        sp_target=sp_target, src_mask=src_mask,
                        ap_mask=ap_mask, sp_mask=sp_mask)
                    ap_l.append(ap_loss.item())
                    sp_l.append(sp_loss.item())
                    sp_p_l.append(sp_postnet_loss.item())
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0, energy_output, energy,
                        mel_output=mel_output, mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target, src_mask=~src_mask, mel_mask=~mel_mask)
                    mel_l.append(mel_loss.item())
                    mel_p_l.append(mel_postnet_loss.item())
                # Shared component losses.
                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())

    # Average each accumulated loss over all validation batches.
    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    if hp.vocoder == 'WORLD':
        ap_l = sum(ap_l) / len(ap_l)
        sp_l = sum(sp_l) / len(sp_l)
        sp_p_l = sum(sp_p_l) / len(sp_p_l)
        # Report/return the WORLD spectral losses in the "mel" slots so the
        # 5-value return contract is the same in both vocoder modes.
        mel_l = sp_l
        mel_p_l = sp_p_l
    else:
        mel_l = sum(mel_l) / len(mel_l)
        mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    # Append the summary to the persistent evaluation log.
    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l
def main(args):
    """Train the WORLD-vocoder variant of FastSpeech2.

    Builds the training dataset/loader, restores an optional checkpoint at
    `args.restore_step`, then runs the epoch loop: forward pass, combined
    loss, gradient accumulation (`hp.acc_steps`), clipping, scheduled LR
    update, periodic console/TensorBoard logging, checkpointing, and
    periodic audio synthesis into `hp.synth_path`.

    Args:
        args: Parsed CLI arguments; only `args.restore_step` is read here.
    """
    torch.manual_seed(0)
    # Get device
    # device = torch.device('cuda'if torch.cuda.is_available()else 'cpu')
    # NOTE(review): device is hard-coded to CUDA; this main will fail on a
    # CPU-only machine.
    device = 'cuda'
    # Get dataset
    dataset = Dataset("train.txt")
    # DataLoaderX: presumably a prefetching DataLoader subclass defined
    # elsewhere in the project — TODO confirm.
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)
    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    # model = FastSpeech2().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)
    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")
    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    # NOTE(review): bare except silently swallows *any* failure (including
    # a corrupted checkpoint) and restarts training from scratch.
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)
    # Load vocoder (only melgan/waveglow need a network; WORLD is signal
    # processing done in utils.world_infer).
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)
    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    # NOTE(review): the SummaryWriters use a hard-coded 'log/' prefix
    # instead of log_path — the directories created above go unused.
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)
    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)
    # Define Some Information
    Time = np.array([])          # rolling per-step wall-clock times for ETA
    Start = time.perf_counter()
    current_step0 = 0
    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()
                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1
                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                #print(D,log_D)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)
                # Forward pass; the WORLD variant predicts aperiodicity (ap)
                # and spectral envelope (sp) instead of a mel spectrogram.
                if hp.vocoder == 'WORLD':
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0, energy_output, energy,
                        ap_output=ap_output, sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output, ap_target=ap_target,
                        sp_target=sp_target, src_mask=src_mask,
                        ap_mask=ap_mask, sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy, max_src_len, max_mel_len)
                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0, energy_output, energy,
                        mel_output=mel_output, mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target, src_mask=~src_mask, mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss
                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                # Backward with gradient accumulation over hp.acc_steps
                # micro-batches; optimizer steps only on the boundary.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue
                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()
                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    # Average seconds/step since start, and ETA from the
                    # rolling Time window.
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))
                    print("" + str1 + str2 + str3 + '')
                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss',
                                                sp_p_l, current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss',
                                                m_p_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)
                # Checkpoint (the `== 20` clause gives one early sanity save).
                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))
                # Periodic synthesis of the first sample in the batch.
                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()
                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        # NOTE(review): mel_target is clobbered here with the
                        # first sample's slice; the full batch tensor is gone
                        # for the rest of this iteration.
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(
                            0).transpose(1, 2)
                        mel = mel_output[0, :length].detach().cpu().transpose(
                            0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
                    # Audio.tools.inv_mel_spec(mel, os.path.join(synth_path, "step_{}_griffin_lim.wav".format(current_step)))
                    # Audio.tools.inv_mel_spec(mel_postnet, os.path.join(synth_path, "step_{}_postnet_griffin_lim.wav".format(current_step)))
                    # NOTE(review): f0/energy and their outputs are clobbered
                    # with numpy slices here as well.
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        # ap=np.swapaxes(ap,0,1)
                        # sp=np.swapaxes(sp,0,1)
                        # WORLD synthesizes from (T, dim) arrays, hence the
                        # axis swap of the (dim, T) tensors sliced above.
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0, 1),
                            f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0, 1),
                            f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        # NOTE(review): the plots below use WORLD-only tensors
                        # (sp_*, ap_torch, variance_adaptor_output); they are
                        # kept inside this branch for that reason.
                        utils.plot_data([
                            (sp_postnet_torch[0].cpu().numpy(), f0_output,
                             energy_output),
                            (sp_target_torch[0].cpu().numpy(), f0, energy)
                        ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                            filename=os.path.join(
                                synth_path,
                                'step_{}.png'.format(current_step)))
                        plt.matshow(sp_postnet_torch[0].cpu().numpy())
                        plt.savefig(
                            os.path.join(
                                synth_path,
                                'sp_postnet_{}.png'.format(current_step)))
                        plt.matshow(ap_torch[0].cpu().numpy())
                        plt.savefig(
                            os.path.join(synth_path,
                                         'ap_{}.png'.format(current_step)))
                        plt.matshow(
                            variance_adaptor_output[0].detach().cpu().numpy())
                        # plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                        # plt.matshow(decoder_output[0].detach().cpu().numpy())
                        # plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))
                        plt.cla()
                        # Dump predicted vs. target durations for inspection.
                        fout = open(
                            os.path.join(synth_path,
                                         'D_{}.txt'.format(current_step)), 'w')
                        fout.write(
                            str(log_duration_output[0].detach().cpu().numpy())
                            + '\n')
                        fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                        fout.write(
                            str(condition[0, :, 2].detach().cpu().numpy())
                            + '\n')
                        fout.close()
                # (A commented-out periodic-evaluation block calling
                # evaluate() and logging to val_logger lived here.)
                # Rolling timing window for ETA; collapsed to its mean every
                # hp.clear_Time steps to bound memory.
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
def main(args):
    """Train the multi-speaker mel FastSpeech2 variant.

    Builds the dataset/loader (with optional speaker-embedding IDs),
    restores a checkpoint at `args.restore_step` if present, then runs the
    epoch loop: forward pass, weighted loss, gradient accumulation
    (`hp.acc_steps`), clipping, scheduled LR update, per-step loss file
    logging, TensorBoard logging, checkpointing, periodic synthesis, and
    periodic validation via evaluate().

    Fixes relative to the original:
      * The waveglow synthesis filename 'step_{}_spk_{}_{}.wav' was
        formatted with only (current_step, hp.vocoder) — two args for three
        placeholders — raising IndexError; `spk_id` is now passed.
      * `spk_id` lookup crashed with TypeError when hp.use_spk_embed is
        False (spk_ids is None); a fallback id is used instead.
      * The bare `except:` around checkpoint loading now catches
        `Exception` so KeyboardInterrupt/SystemExit are not swallowed.

    Args:
        args: Parsed CLI arguments; only `args.restore_step` is read here.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    # batch_size**2 here: the collate_fn re-splits each loader batch into
    # hp.batch_size-sized sub-batches (the inner `batchs` loop below).
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model (with a speaker-embedding table when enabled).
    if hp.use_spk_embed:
        n_pkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_pkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists; otherwise start fresh.
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        #melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Init evaluation directory
    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Define Some Information
    Time = np.array([])          # rolling per-step wall-clock times for ETA
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()
                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                # F0 and energy losses are down-weighted relative to the
                # spectral and duration terms.
                total_loss = mel_loss + mel_postnet_loss + d_loss + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                # Per-step loss traces appended to plain-text files.
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                ## Backward (gradient accumulation over hp.acc_steps)
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start), (total_step - current_step) * np.mean(Time))
                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")
                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    # Fix: without speaker embeddings, spk_ids is None, so
                    # index the inverse table only when it is available.
                    if spk_ids is not None:
                        spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    else:
                        spk_id = 'none'
                    # NOTE(review): these call `vocoder.*_infer` while the
                    # sibling training loop uses `utils.*_infer` — confirm a
                    # `vocoder` module is imported at file level.
                    if hp.vocoder == 'melgan':
                        vocoder.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        # Fix: original formatted three placeholders with two
                        # args (spk_id was missing) -> IndexError.
                        vocoder.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(
                            synth_path, 'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                    model.train()

                ### Time ###
                # Rolling timing window; collapsed to its mean every
                # hp.clear_Time steps to bound memory.
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
def main(args):
    """Stand-alone evaluation entry point.

    Restores a FastSpeech2 checkpoint (``args.step``), scores it on the
    validation set with teacher-forced durations, vocodes every utterance
    with WaveGlow into ``hp.eval_path`` and appends a loss summary to
    ``hp.logger_path``/eval.txt.

    Args:
        args: parsed CLI namespace; only ``args.step`` is read here.
    """
    torch.manual_seed(0)  # fixed seed so evaluation runs are reproducible

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        # NOTE(review): batch_size**2 suggests collate_fn re-splits into
        # hp.batch_size-sized sub-batches (the inner loop below) — confirm.
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get model
    model = get_FastSpeech2(args.step).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Init directories
    if not os.path.exists(hp.logger_path):
        os.makedirs(hp.logger_path)
    if not os.path.exists(hp.eval_path):
        os.makedirs(hp.eval_path)

    # Get loss function
    Loss = FastSpeech2Loss().to(device)
    print("Loss Function Defined.")

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Evaluation: per-batch scalar losses, averaged at the end
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0  # running utterance index, used in every output file name
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            mel_pos = torch.from_numpy(
                data_of_batch["mel_pos"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["src_pos"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            # NOTE(review): int16 caps max_len at 32767 frames — fine for
            # speech-length clips, but int32 (as used elsewhere in this
            # file) would be safer.
            max_len = max(data_of_batch["mel_len"]).astype(np.int16)

            with torch.no_grad():
                # Forward (durations teacher-forced with ground-truth D)
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                # Vocode and plot every utterance in the sub-batch
                for k in range(len(mel_target)):
                    length = mel_len[k]
                    # (1, n_mels, T) view for the vocoder, (n_mels, T) for plots
                    mel_target_torch = mel_target[k:k + 1, :length].transpose(
                        1, 2).detach()
                    mel_target_ = mel_target[k, :length].cpu().transpose(
                        0, 1).detach()
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            hp.eval_path,
                            'ground-truth_{}_waveglow.wav'.format(idx)))
                    mel_postnet_torch = mel_postnet_output[
                        k:k + 1, :length].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[k, :length].cpu(
                    ).transpose(0, 1).detach()
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(hp.eval_path,
                                     'eval_{}_waveglow.wav'.format(idx)))
                    # No predicted f0/energy overlays here, hence the Nones
                    utils.plot_data(
                        [(mel_postnet.numpy(), None, None),
                         (mel_target_.numpy(), None, None)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(hp.eval_path,
                                              'eval_{}.png'.format(idx)))
                    idx += 1
            current_step += 1

    # Average per-batch losses into single scalars
    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(args.step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)
    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    # Append the summary to the persistent eval log
    with open(os.path.join(hp.logger_path, "eval.txt"), "a") as f_logger:
        f_logger.write(str1 + "\n")
        f_logger.write(str2 + "\n")
        f_logger.write(str3 + "\n")
        f_logger.write(str4 + "\n")
        f_logger.write(str5 + "\n")
        f_logger.write(str6 + "\n")
        f_logger.write("\n")
def main(args):
    """Training entry point (normalized-feature / VocGAN variant).

    Trains FastSpeech2 with teacher-forced duration/F0/energy targets,
    gradient accumulation over ``hp.acc_steps``, Noam-style LR scheduling
    (``ScheduledOptim``), periodic checkpointing, TensorBoard logging to
    separate train/validation writers, and periodic validation via
    ``evaluate``.

    Args:
        args: parsed CLI namespace; only ``args.restore_step`` is read.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model (wrapped in DataParallel regardless of GPU count)
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss; the LR itself is driven by ScheduledOptim below
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        # NOTE(review): bare except hides *any* restore failure (corrupt
        # checkpoint, device mismatch, ...) and silently starts fresh —
        # consider narrowing to FileNotFoundError.
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # read params: each *_stat.npy stores a (mean, std) pair
    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
                                     dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
                                   dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
                                           dtype=torch.float).to(device)
    # Reshape to broadcastable (1, dim).  NOTE(review): these stats are not
    # referenced again inside this function — possibly leftover code.
    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), std_energy.reshape(
        1, -1)

    # Load vocoder (only VocGAN is supported for in-training synthesis here)
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)
    else:
        vocoder = None

    # Init logger: separate SummaryWriters for train and validation curves
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Define Some Information
    Time = np.array([])  # rolling window of per-step wall-clock times
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # Global step, accounting for restore offset and epoch
                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len)

                # Cal Loss (masks are inverted so True marks valid positions)
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger: append every step's scalars to plain-text histories
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward with gradient accumulation: only every
                # hp.acc_steps-th step clips/updates; others just accumulate.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))
                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    with open(os.path.join(log_path, "log.txt"),
                              "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")
                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    # NOTE(review): the validation means below rebind the
                    # training scalars d_l/f_l/e_l — harmless since they
                    # are recomputed next step, but easy to misread.
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step, vocoder)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the window to its mean so the array stays small
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
def evaluate(model, step, vocoder=None):
    """Validate the model on val.txt (VocGAN/normalized-feature variant).

    Computes mean duration/F0/energy/mel/postnet losses with teacher-forced
    D, f0 and energy targets.  If *vocoder* is given, additionally vocodes
    and plots the first validation utterance (once per call) under
    ``hp.eval_path`` and appends a text summary to ``hp.log_path``/eval.txt.

    Bug fixed: the synthesized F0 contour was de-normalized from the whole
    batch tensor ``f0_output`` instead of the per-utterance slice
    ``f0_output_``, so the plotted/saved predicted F0 was wrong.

    Args:
        model: FastSpeech2 module (switched to eval mode here, restored to
            train mode on return).
        step: global training step, used in output file names and the log.
        vocoder: optional neural vocoder module for audio synthesis.

    Returns:
        Tuple of mean (duration, f0, energy, mel, mel-postnet) losses.
    """
    model.eval()
    torch.manual_seed(0)  # deterministic evaluation

    # Normalization statistics: each *_stat.npy stores a (mean, std) pair.
    mean_mel, std_mel = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "mel_stat.npy")),
        dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "f0_stat.npy")),
        dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(np.load(
        os.path.join(hp.preprocessed_path, "energy_stat.npy")),
        dtype=torch.float).to(device)

    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Per-batch scalar losses, averaged at the end
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0  # guards the one-shot vocoding below
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward (teacher-forced with ground-truth D/f0/energy)
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, out_mel_len = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len)

                # Cal Loss (masks inverted: True marks valid positions)
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                # Vocode/plot a single sample, only once per evaluation
                if idx == 0 and vocoder is not None:
                    for k in range(1):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[k:k + 1, :gt_length]
                        mel_target_ = mel_target[k, :gt_length]
                        mel_postnet_torch = mel_postnet_output[
                            k:k + 1, :out_length]
                        mel_postnet = mel_postnet_output[k, :out_length]

                        # De-normalize and move to (n_mels, T) layout
                        mel_target_torch = utils.de_norm(
                            mel_target_torch, mean_mel,
                            std_mel).transpose(1, 2).detach()
                        mel_target_ = utils.de_norm(
                            mel_target_, mean_mel,
                            std_mel).cpu().transpose(0, 1).detach()
                        mel_postnet_torch = utils.de_norm(
                            mel_postnet_torch, mean_mel,
                            std_mel).transpose(1, 2).detach()
                        mel_postnet = utils.de_norm(
                            mel_postnet, mean_mel,
                            std_mel).cpu().transpose(0, 1).detach()

                        if hp.vocoder == "vocgan":
                            utils.vocgan_infer(
                                mel_target_torch,
                                vocoder,
                                path=os.path.join(
                                    hp.eval_path,
                                    'eval_groundtruth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.vocgan_infer(
                                mel_postnet_torch,
                                vocoder,
                                path=os.path.join(
                                    hp.eval_path,
                                    'eval_{}_{}_{}.wav'.format(
                                        step, basename, hp.vocoder)))
                        np.save(
                            os.path.join(
                                hp.eval_path,
                                'eval_step_{}_{}_mel.npy'.format(
                                    step, basename)), mel_postnet.numpy())

                        f0_ = f0[k, :gt_length]
                        energy_ = energy[k, :gt_length]
                        f0_output_ = f0_output[k, :out_length]
                        energy_output_ = energy_output[k, :out_length]
                        f0_ = utils.de_norm(f0_, mean_f0,
                                            std_f0).detach().cpu().numpy()
                        # BUGFIX: de-normalize the per-utterance slice
                        # ``f0_output_``, not the whole batch ``f0_output``.
                        f0_output_ = utils.de_norm(
                            f0_output_, mean_f0,
                            std_f0).detach().cpu().numpy()
                        energy_ = utils.de_norm(
                            energy_, mean_energy,
                            std_energy).detach().cpu().numpy()
                        energy_output_ = utils.de_norm(
                            energy_output_, mean_energy,
                            std_energy).detach().cpu().numpy()

                        utils.plot_data(
                            [(mel_postnet.numpy(), f0_output_,
                              energy_output_),
                             (mel_target_.numpy(), f0_, energy_)],
                            [
                                'Synthesized Spectrogram',
                                'Ground-Truth Spectrogram'
                            ],
                            filename=os.path.join(
                                hp.eval_path,
                                'eval_step_{}_{}.png'.format(step, basename)))
                        idx += 1
                        print("done")
            current_step += 1

    # Average per-batch losses into scalars
    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)
    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str3 + "\n")
        f_log.write(str4 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    model.train()
    return d_l, f_l, e_l, mel_l, mel_p_l
collate_fn=dataset.collate_fn, drop_last=False, num_workers=8) # Define model model = FastSpeech2(py_vocab_size, hz_vocab_size).to(device) num_param = utils.get_param_num(model) # Optimizer and loss optimizer = torch.optim.Adam(model.parameters(), lr=hp.start_lr, betas=hp.betas, eps=hp.eps, weight_decay=0) scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden, hp.n_warm_up_step, args.restore_step) Loss = FastSpeech2Loss().to(device) print("Optimizer and Loss Function Defined.") # Load checkpoint if exists checkpoint_path = os.path.join(hp.checkpoint_path) try: checkpoint = torch.load( os.path.join(checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))) #temp = nn.DataParallel(model) model.load_state_dict(checkpoint['model']) #model.load_state_dict(temp.module.state_dict()) #del temp optimizer.load_state_dict(checkpoint['optimizer']) print("\n---Model Restored at Step {}---\n".format(args.restore_step)) except:
def evaluate(model, step, vocoder=None):
    """Validate the duration-predicting FastSpeech2 variant on val.txt.

    This model variant predicts durations itself (no D/f0/energy inputs)
    and returns attention maps instead of f0/energy heads, so only
    duration/mel/postnet losses are computed.  If *vocoder* is given,
    every utterance is vocoded and plotted under ``hp.eval_path``.

    Bugs fixed (the vocoder branch previously raised NameError on first use):
    * ``out_mel_len`` was never defined — the model call above returns no
      output-length tensor.  Since ``mel_len`` is passed to the model, the
      output is taken to be aligned to the target length and ``gt_length``
      is used for the output slice as well (TODO confirm against the model).
    * ``f0_output_`` / ``energy_output_`` were never defined — this variant
      predicts no F0/energy, so the synthesized plot now carries ``None``
      overlays (matching the style used elsewhere in this file).
    Also removed the unused ``f_l``/``e_l`` accumulators and the unused
    ``log_D`` load.

    Returns:
        Tuple of mean (duration, mel, mel-postnet) losses.
    """
    torch.manual_seed(0)  # deterministic evaluation

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Per-batch scalar losses, averaged at the end
    d_l = []
    mel_l = []
    mel_p_l = []
    current_step = 0
    idx = 0
    for i, batchs in enumerate(loader):
        for j, data_of_batch in enumerate(batchs):
            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward: durations predicted by the model itself
                mel_output, mel_postnet_output, duration_output, src_mask, pred_mel_mask, enc_attns, dec_attns, W = model(
                    text, src_len, mel_len, max_src_len, max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    duration_output, mel_len, mel_output, mel_postnet_output,
                    mel_target, src_mask, pred_mel_mask)

                d_l.append(d_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the
                    # vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        # BUGFIX: the original referenced the undefined
                        # ``out_mel_len``; with mel_len fed to the model the
                        # output is aligned to the target — TODO confirm.
                        out_length = gt_length

                        mel_target_torch = mel_target[
                            k:k + 1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[
                            k, :gt_length].cpu().transpose(0, 1).detach()
                        mel_postnet_torch = mel_postnet_output[
                            k:k + 1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[
                            k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'ground-truth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.melgan_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'eval_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'ground-truth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.waveglow_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'eval_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                        np.save(
                            os.path.join(hp.eval_path,
                                         'eval_{}_mel.npy'.format(basename)),
                            mel_postnet.numpy())

                        # Ground-truth prosody overlays only; this variant
                        # predicts no F0/energy (BUGFIX: the originals
                        # referenced undefined f0_output_/energy_output_).
                        f0_ = f0[k, :gt_length].detach().cpu().numpy()
                        energy_ = energy[k, :gt_length].detach().cpu().numpy()
                        utils.plot_data(
                            [(mel_postnet.numpy(), None, None),
                             (mel_target_.numpy(), f0_, energy_)],
                            [
                                'Synthesized Spectrogram',
                                'Ground-Truth Spectrogram'
                            ],
                            filename=os.path.join(
                                hp.eval_path,
                                'eval_{}.png'.format(basename)))
                        idx += 1
            current_step += 1

    d_l = sum(d_l) / len(d_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)
    print("\n" + str1)
    print(str2)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        f_log.write(str1 + "\n")
        f_log.write(str2 + "\n")
        f_log.write(str5 + "\n")
        f_log.write(str6 + "\n")
        f_log.write("\n")

    return d_l, mel_l, mel_p_l
def main(args):
    """Training entry point (WaveGlow / Griffin-Lim variant).

    Trains FastSpeech2 with teacher-forced D/f0/energy, logs scalars to a
    single SummaryWriter (train/validation separated by scalar tags),
    checkpoints periodically, and at every ``hp.synth_step`` synthesizes
    the first utterance of the current batch with both Griffin-Lim and
    WaveGlow into ``hp.synth_path``.

    Args:
        args: parsed CLI namespace; only ``args.restore_step`` is read.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model (wrapped in DataParallel regardless of GPU count)
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss; the LR itself is driven by ScheduledOptim below
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        # NOTE(review): bare except hides *any* restore failure and starts
        # a fresh run — consider narrowing to FileNotFoundError.
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define Some Information
    Time = np.array([])  # rolling window of per-step wall-clock times
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get Training Loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                # Global step, accounting for restore offset and epoch
                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                # NOTE(review): int16 caps max_len at 32767 frames
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)

                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, energy_output = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target, src_len,
                    mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger: append every step's scalars to plain-text histories
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))
                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    with open(os.path.join(log_path, "log.txt"),
                              "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")
                    # Train/validation curves share tags; separated by the
                    # 'training'/'validation' sub-tag of add_scalars
                    logger.add_scalars('Loss/total_loss', {'training': t_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_postnet_loss',
                                       {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss',
                                       {'training': d_l}, current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l},
                                       current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l},
                                       current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    # NOTE(review): the assignments below rebind the batch
                    # tensors (mel_target, f0, energy, f0_output,
                    # energy_output) to single-utterance slices/arrays; this
                    # is safe only because the batch is not used again this
                    # iteration — keep statement order intact.
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    # Griffin-Lim reconstructions of pre- and post-net mels
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))
                    # WaveGlow renditions: pre-net, post-net and ground truth
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_waveglow.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_ground-truth_waveglow.wav".format(
                                current_step)))
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    # Validation; rebinds the training scalars d_l..m_p_l
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        logger.add_scalars('Loss/total_loss',
                                           {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss',
                                           {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss',
                                           {'validation': m_p_l},
                                           current_step)
                        logger.add_scalars('Loss/duration_loss',
                                           {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss',
                                           {'validation': f_l}, current_step)
                        logger.add_scalars('Loss/energy_loss',
                                           {'validation': e_l}, current_step)
                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the window to its mean so the array stays small
                    temp_value = np.mean(Time)
                    Time = np.delete(Time, [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
def evaluate(model, step, wave_glow=None):
    """Score the model on the validation set (val.txt) with teacher-forced
    durations.

    When *wave_glow* is provided, every validation utterance is also
    vocoded (ground truth and postnet output) and a predicted-vs-target
    spectrogram/F0/energy plot is written under ``hp.eval_path``.

    Args:
        model: FastSpeech2 module to score.
        step: global training step, used only in the printed/logged summary.
        wave_glow: optional WaveGlow vocoder; skips synthesis when None.

    Returns:
        Mean (duration, f0, energy, mel, mel-postnet) losses.
    """
    torch.manual_seed(0)  # deterministic evaluation

    val_set = Dataset("val.txt", sort=False)
    val_loader = DataLoader(
        val_set,
        batch_size=hp.batch_size**2,
        shuffle=False,
        collate_fn=val_set.collate_fn,
        drop_last=False,
        num_workers=0,
    )
    criterion = FastSpeech2Loss().to(device)

    # Per-batch scalar losses, keyed by loss name; averaged on return.
    losses = {"d": [], "f": [], "e": [], "mel": [], "mel_post": []}
    sample_idx = 0  # running utterance index used in output file names

    for batch_group in val_loader:
        for batch in batch_group:
            # Move the collated numpy arrays onto the compute device.
            text = torch.from_numpy(batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                batch["mel_target"]).float().to(device)
            D = torch.from_numpy(batch["D"]).int().to(device)
            f0 = torch.from_numpy(batch["f0"]).float().to(device)
            energy = torch.from_numpy(batch["energy"]).float().to(device)
            mel_pos = torch.from_numpy(batch["mel_pos"]).long().to(device)
            src_pos = torch.from_numpy(batch["src_pos"]).long().to(device)
            src_len = torch.from_numpy(batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(batch["mel_len"]).long().to(device)
            max_len = max(batch["mel_len"]).astype(np.int16)

            with torch.no_grad():
                # Forward pass, durations teacher-forced with ground-truth D.
                (mel_output, mel_postnet_output, duration_output, f0_output,
                 energy_output) = model(text, src_pos, mel_pos, max_len, D)

                (mel_loss, mel_postnet_loss, d_loss, f_loss,
                 e_loss) = criterion(duration_output, D, f0_output, f0,
                                     energy_output, energy, mel_output,
                                     mel_postnet_output, mel_target, src_len,
                                     mel_len)

                losses["d"].append(d_loss.item())
                losses["f"].append(f_loss.item())
                losses["e"].append(e_loss.item())
                losses["mel"].append(mel_loss.item())
                losses["mel_post"].append(mel_postnet_loss.item())

                if wave_glow is not None:
                    # Vocode and plot each utterance of the sub-batch.
                    for k in range(len(mel_target)):
                        length = mel_len[k]
                        # (1, n_mels, T) views feed the vocoder; (n_mels, T)
                        # views feed the plot.
                        gt_for_vocoder = mel_target[
                            k:k + 1, :length].transpose(1, 2).detach()
                        gt_for_plot = mel_target[k, :length].cpu().transpose(
                            0, 1).detach()
                        waveglow.inference.inference(
                            gt_for_vocoder, wave_glow,
                            os.path.join(
                                hp.eval_path,
                                'ground-truth_{}_waveglow.wav'.format(
                                    sample_idx)))
                        pred_for_vocoder = mel_postnet_output[
                            k:k + 1, :length].transpose(1, 2).detach()
                        pred_for_plot = mel_postnet_output[
                            k, :length].cpu().transpose(0, 1).detach()
                        waveglow.inference.inference(
                            pred_for_vocoder, wave_glow,
                            os.path.join(
                                hp.eval_path,
                                'eval_{}_waveglow.wav'.format(sample_idx)))

                        f0_gt = f0[k, :length].detach().cpu().numpy()
                        energy_gt = energy[k, :length].detach().cpu().numpy()
                        f0_pred = f0_output[k, :length].detach().cpu().numpy()
                        energy_pred = energy_output[
                            k, :length].detach().cpu().numpy()
                        utils.plot_data(
                            [(pred_for_plot.numpy(), f0_pred, energy_pred),
                             (gt_for_plot.numpy(), f0_gt, energy_gt)],
                            [
                                'Synthesized Spectrogram',
                                'Ground-Truth Spectrogram'
                            ],
                            filename=os.path.join(
                                hp.eval_path,
                                'eval_{}.png'.format(sample_idx)))
                        sample_idx += 1

    def _mean(values):
        # Arithmetic mean of the collected per-batch losses.
        return sum(values) / len(values)

    d_l = _mean(losses["d"])
    f_l = _mean(losses["f"])
    e_l = _mean(losses["e"])
    mel_l = _mean(losses["mel"])
    mel_p_l = _mean(losses["mel_post"])

    report = [
        "FastSpeech2 Step {},".format(step),
        "Duration Loss: {}".format(d_l),
        "F0 Loss: {}".format(f_l),
        "Energy Loss: {}".format(e_l),
        "Mel Loss: {}".format(mel_l),
        "Mel Postnet Loss: {}".format(mel_p_l),
    ]
    print("\n" + report[0])
    for line in report[1:]:
        print(line)

    # Append the same summary to the persistent eval log.
    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        for line in report:
            f_log.write(line + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l