def synthesize(model, text, sentence, prefix=''):
    """Synthesize one utterance and write its audio and plot into ``hp.test_path``.

    The model is moved to ``device`` for the forward pass and back to the CPU
    afterwards. A Griffin-Lim waveform is always written; a neural-vocoder
    waveform is written in addition when ``hp.vocoder`` is ``'melgan'`` or
    ``'waveglow'``. Finally the post-net spectrogram is plotted together with
    the predicted f0 and energy.

    Args:
        model: Callable invoked as ``model(text, src_pos)`` and returning five
            values (mel, post-net mel, duration, f0, energy).
        text: Input sequence; only ``text.shape[1]`` is read here to build
            the position ids (presumably a ``(1, T)`` id tensor — confirm
            against the caller).
        sentence: String embedded verbatim in every output file name.
        prefix: String prepended to every output file name.
    """
    # 1-based position ids for a single sequence: [[1, 2, ..., T]].
    src_pos = torch.arange(1, text.shape[1] + 1, device=device).long().unsqueeze(0)

    model.to(device)
    # The raw mel and the duration prediction are not used below.
    _mel, mel_postnet, _duration_output, f0_output, energy_output = model(
        text, src_pos)
    model.to('cpu')

    # Batched, transposed copy for the neural vocoders; single-item CPU copy
    # for Griffin-Lim and plotting.
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(
        mel_postnet,
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    vocoder_wav_path = os.path.join(
        hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_wav_path)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_wav_path)

    utils.plot_data([(mel_postnet.numpy(), f0_output, energy_output)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(
                        hp.test_path, '{}_{}.png'.format(prefix, sentence)))
f_log.write("\n") return d_l, f_l, e_l, mel_l, mel_p_l if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument('--step', type=int, default=30000) args = parser.parse_args() # Get model model = get_FastSpeech2(args.step).to(device) print("Model Has Been Defined") num_param = utils.get_param_num(model) print('Number of FastSpeech2 Parameters:', num_param) # Load vocoder if hp.vocoder == 'melgan': vocoder = utils.get_melgan() elif hp.vocoder == 'waveglow': vocoder = utils.get_waveglow() vocoder.to(device) # Init directories if not os.path.exists(hp.log_path): os.makedirs(hp.log_path) if not os.path.exists(hp.eval_path): os.makedirs(hp.eval_path) evaluate(model, args.step, vocoder)
def main(args):
    """Train FastSpeech2 on ``train.txt`` with periodic logging, checkpointing,
    sample synthesis and evaluation.

    All intervals, paths and hyper-parameters are read from ``hp``. Training
    resumes from ``checkpoint_<restore_step>.pth.tar`` in ``hp.checkpoint_path``
    when that file can be loaded; otherwise it starts from scratch.

    Args:
        args: Namespace providing ``restore_step``; it selects the checkpoint
            to restore and offsets the running step counter.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_pkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_pkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    os.makedirs(checkpoint_path, exist_ok=True)
    try:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only host instead of failing.
        checkpoint = torch.load(os.path.join(
            checkpoint_path,
            'checkpoint_{}.pth.tar'.format(args.restore_step)),
                                map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception as err:
        # Still best-effort (any load failure starts a fresh run), but no
        # longer swallows KeyboardInterrupt/SystemExit, and the reason is shown.
        print("\n---Start New Training---\n")
        print("(checkpoint not restored: {}: {})".format(
            type(err).__name__, err))

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        # NOTE(review): the original left ``melgan.to(device)`` commented out.
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'train'), exist_ok=True)
    os.makedirs(os.path.join(log_path, 'validation'), exist_ok=True)
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Init evaluation directory
    eval_path = hp.eval_path
    os.makedirs(eval_path, exist_ok=True)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    # Loop-invariant, so computed once instead of once per epoch.
    total_step = hp.epochs * len(loader) * hp.batch_size

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + epoch * len(
                    loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None

                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                    text, src_len, mel_len, D, f0, energy, max_src_len,
                    max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0, energy_output,
                    energy, mel_output, mel_postnet_output, mel_target,
                    ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                # One append-only text file per loss term, one value per step.
                for file_stem, value in (('total_loss', t_l),
                                         ('mel_loss', m_l),
                                         ('mel_postnet_loss', m_p_l),
                                         ('duration_loss', d_l),
                                         ('f0_loss', f_l),
                                         ('energy_loss', e_l)):
                    with open(os.path.join(log_path, file_stem + ".txt"),
                              "a") as f_out:
                        f_out.write(str(value) + "\n")

                ## Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                # Gradient accumulation: only every acc_steps-th step goes on
                # to update weights, log, save, synthesize and evaluate.
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, F0 Loss: {:.4f}, Energy Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"),
                              "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    for tag, value in (('total_loss', t_l),
                                       ('mel_loss', m_l),
                                       ('mel_postnet_loss', m_p_l),
                                       ('duration_loss', d_l),
                                       ('F0_loss', f_l),
                                       ('energy_loss', e_l)):
                        train_logger.add_scalar('Loss/' + tag, value,
                                                current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))

                    # First item of the batch, cut to its mel length. The
                    # *_torch variants stay on ``device`` for the vocoder;
                    # the CPU variants are used for plotting.
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach().unsqueeze(
                        0).transpose(1, 2)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    # spk_ids is None when hp.use_spk_embed is off; indexing
                    # it used to raise TypeError. 'single' is the file-name
                    # placeholder for that case.
                    if spk_ids is not None:
                        spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    else:
                        spk_id = 'single'

                    # NOTE(review): ``vocoder`` is not defined in this
                    # function; it must be a module-level name — confirm.
                    if hp.vocoder == 'melgan':
                        infer, vocoder_model = vocoder.melgan_infer, melgan
                    elif hp.vocoder == 'waveglow':
                        infer, vocoder_model = vocoder.waveglow_infer, waveglow
                    else:
                        infer = None

                    if infer is not None:
                        # Both vocoders share one naming scheme. The waveglow
                        # branch used to format 'step_{}_spk_{}_{}.wav' with
                        # only two arguments, which raised IndexError.
                        for tag, mel_in in (('', mel_torch),
                                            ('postnet_', mel_postnet_torch),
                                            ('ground-truth_',
                                             mel_target_torch)):
                            infer(
                                mel_in, vocoder_model,
                                os.path.join(
                                    hp.synth_path,
                                    'step_{}_spk_{}_{}{}.wav'.format(
                                        current_step, spk_id, tag,
                                        hp.vocoder)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data([
                        (mel_postnet.numpy(), f0_output, energy_output),
                        (mel_target.numpy(), f0, energy)
                    ], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                                    filename=os.path.join(
                                        synth_path,
                                        'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l

                        for tag, value in (('total_loss', t_l),
                                           ('mel_loss', m_l),
                                           ('mel_postnet_loss', m_p_l),
                                           ('duration_loss', d_l),
                                           ('F0_loss', f_l),
                                           ('energy_loss', e_l)):
                            val_logger.add_scalar('Loss/' + tag, value,
                                                  current_step)

                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing history to its mean so the buffer
                    # stays bounded.
                    Time = np.array([np.mean(Time)])
def main(args):
    """Train the conditioned FastSpeech2 variant, with either a mel target or
    WORLD (ap/sp) targets depending on ``hp.vocoder``.

    All intervals, paths and hyper-parameters are read from ``hp``. Training
    resumes from ``checkpoint_<restore_step>.pth.tar`` in ``hp.checkpoint_path``
    when that file can be loaded; otherwise it starts from scratch.
    Validation is not run here: ``val_logger`` is created but never written.

    Args:
        args: Namespace providing ``restore_step``; it selects the checkpoint
            to restore and offsets the running step counter.
    """
    torch.manual_seed(0)

    # Get device. Falls back to the CPU instead of crashing when CUDA is
    # unavailable; identical to the previous hard-coded 'cuda' on GPU hosts.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    os.makedirs(checkpoint_path, exist_ok=True)
    try:
        checkpoint = torch.load(os.path.join(
            checkpoint_path,
            'checkpoint_{}.pth.tar'.format(args.restore_step)),
                                map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception as err:
        # Still best-effort (any load failure starts a fresh run), but no
        # longer swallows KeyboardInterrupt/SystemExit, and the reason is shown.
        print("\n---Start New Training---\n")
        print("(checkpoint not restored: {}: {})".format(
            type(err).__name__, err))

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'train'), exist_ok=True)
    os.makedirs(os.path.join(log_path, 'validation'), exist_ok=True)
    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    # NOTE(review): the writers go to a relative 'log/...' directory, not to
    # hp.log_path — kept as in the original; confirm this is intended.
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)

    # Init synthesis directory
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    # Loop-invariant, so computed once instead of once per epoch.
    total_step = hp.epochs * len(loader) * hp.batch_size

    def _first_utterance(batch, length):
        # First item of the batch cut to ``length`` frames, re-batched and
        # transposed to (1, channels, frames).
        return batch[0, :length].detach().unsqueeze(0).transpose(1, 2)

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader)*len(batchs) + 1

                # Get Data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward + loss
                if hp.vocoder == 'WORLD':
                    ap_output, sp_output, sp_postnet_output, log_duration_output, f0_output, energy_output, src_mask, ap_mask, sp_mask, variance_adaptor_output, decoder_output = model(
                        condition, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    # NOTE(review): this branch passes D where the mel branch
                    # passes log_D — kept as in the original; confirm.
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output,
                        log_D,
                        f0_output,
                        f0,
                        energy_output,
                        energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                # Backward
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                # Gradient accumulation: only every acc_steps-th step goes on
                # to update weights, log, save and synthesize.
                if current_step % hp.acc_steps != 0:
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))

                    print(str1 + str2 + str3)

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss',
                                                sp_p_l, current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss',
                                                m_p_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                # Step 20 gets an extra early checkpoint.
                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                # Step 5 gets an extra early synthesis.
                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()
                    if hp.vocoder == 'WORLD':
                        ap_target_torch = _first_utterance(ap_target, length)
                        ap_torch = _first_utterance(ap_output, length)
                        sp_target_torch = _first_utterance(sp_target, length)
                        sp_postnet_torch = _first_utterance(
                            sp_postnet_output, length)
                    else:
                        mel_target_torch = _first_utterance(
                            mel_target, length)
                        mel_torch = _first_utterance(mel_output, length)
                        mel_postnet_torch = _first_utterance(
                            mel_postnet_output, length)

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        # NOTE(review): the source's indentation was lost, so
                        # it is an assumption that the plots and the duration
                        # dump below belong to this branch. They read
                        # WORLD-only tensors — confirm against the original.
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0,
                                        1), f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0,
                                        1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0,
                                        1), f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                        utils.plot_data([
                            (sp_postnet_torch[0].cpu().numpy(), f0_output,
                             energy_output),
                            (sp_target_torch[0].cpu().numpy(), f0, energy)
                        ], [
                            'Synthetized Spectrogram',
                            'Ground-Truth Spectrogram'
                        ],
                                        filename=os.path.join(
                                            synth_path, 'step_{}.png'.format(
                                                current_step)))

                        # matshow opens a new figure on every call; close each
                        # one after saving so figures do not pile up.
                        for stem, spec in (('sp_postnet', sp_postnet_torch),
                                           ('ap', ap_torch)):
                            plt.matshow(spec[0].cpu().numpy())
                            plt.savefig(
                                os.path.join(
                                    synth_path,
                                    '{}_{}.png'.format(stem, current_step)))
                            plt.close()

                        # Dump predicted vs. target durations and one
                        # condition channel for manual inspection.
                        with open(
                                os.path.join(
                                    synth_path,
                                    'D_{}.txt'.format(current_step)),
                                'w') as fout:
                            fout.write(
                                str(log_duration_output[0].detach().cpu().
                                    numpy()) + '\n')
                            fout.write(
                                str(D[0].detach().cpu().numpy()) + '\n')
                            fout.write(
                                str(condition[0, :, 2].detach().cpu().numpy())
                                + '\n')

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing history to its mean so the buffer
                    # stays bounded.
                    Time = np.array([np.mean(Time)])
parser.add_argument('--energy_control', type=float, default=1.0) args = parser.parse_args() sentences = [ "Advanced text to speech models such as Fast Speech can synthesize speech significantly faster than previous auto regressive models with comparable quality. The training of Fast Speech model relies on an auto regressive teacher model for duration prediction and knowledge distillation, which can ease the one to many mapping problem in T T S. However, Fast Speech has several disadvantages, 1, the teacher student distillation pipeline is complicated, 2, the duration extracted from the teacher model is not accurate enough, and the target mel spectrograms distilled from teacher model suffer from information loss due to data simplification, both of which limit the voice quality.", "Printing, in the only sense with which we are at present concerned, differs from most if not from all the arts and crafts represented in the Exhibition", "in being comparatively modern.", "For although the Chinese took impressions from wood blocks engraved in relief for centuries before the woodcutters of the Netherlands, by a similar process", "produced the block books, which were the immediate predecessors of the true printed book,", "the invention of movable metal letters in the middle of the fifteenth century may justly be considered as the invention of the art of printing.", "And it is worth mention in passing that, as an example of fine typography,", "the earliest book printed with movable types, the Gutenberg, or \"forty-two line Bible\" of about 1455,", "has never been surpassed.", "Printing, then, for our purpose, may be considered as the art of making books by means of movable types.", "Now, as all books not primarily intended as picture-books consist principally of types composed to form letterpress," ] model = get_FastSpeech2(args.step).to(device) melgan = waveglow = None if hp.vocoder == 'melgan': melgan = utils.get_melgan() elif hp.vocoder == 'waveglow': waveglow = 
utils.get_waveglow() with torch.no_grad(): for sentence in sentences: text = preprocess(sentence) synthesize(model, waveglow, melgan, text, sentence, 'step_{}'.format(args.step), args.duration_control, args.pitch_control, args.energy_control)
def main(args, device):
    """Train FastSpeech from the buffered ``train.txt`` data set.

    Output locations (checkpoints, synthesis, evaluation, logs, results) are
    derived from ``hp.root_path`` and ``args.name_task`` and written back onto
    ``hp``. Training resumes from ``checkpoint_<restore_step>.pth.tar`` in
    ``args.restore_path`` when that file can be loaded; otherwise it starts
    from scratch. New checkpoints are always saved to ``hp.checkpoint_path``.

    Args:
        args: Namespace providing ``name_task``, ``restore_path``,
            ``restore_step``, ``frozen_learning_rate`` and
            ``learning_rate_frozen``.
        device: Device that the model, the loss and every batch are moved to.
    """
    hp.checkpoint_path = os.path.join(hp.root_path, args.name_task, "ckpt",
                                      hp.dataset)
    hp.synth_path = os.path.join(hp.root_path, args.name_task, "synth",
                                 hp.dataset)
    hp.eval_path = os.path.join(hp.root_path, args.name_task, "eval",
                                hp.dataset)
    hp.log_path = os.path.join(hp.root_path, args.name_task, "log",
                               hp.dataset)
    hp.test_path = os.path.join(hp.root_path, args.name_task, 'results')

    # Define model
    print("Use FastSpeech")
    model = nn.DataParallel(FastSpeech()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of TTS Parameters:', num_param)

    # Get buffer
    print("Load data to buffer")
    buffer = get_data_to_buffer('train.txt')

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=(0.9, 0.98),
                                 eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer, hp1.decoder_dim,
                                     hp1.n_warm_up_step, args.restore_step)
    fastspeech_loss = DNNLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Load checkpoint if exists
    restore_dir = os.path.join(args.restore_path)
    try:
        checkpoint = torch.load(os.path.join(
            restore_dir, 'checkpoint_%d.pth.tar' % args.restore_step),
                                map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step %d---\n" % args.restore_step)
    except Exception as err:
        # Still best-effort (any load failure starts a fresh run), but no
        # longer swallows KeyboardInterrupt/SystemExit, and the reason is shown.
        print("\n---Start New Training---\n")
        print("(checkpoint not restored: {}: {})".format(
            type(err).__name__, err))

    # Checkpoints are always written to hp.checkpoint_path, which can differ
    # from args.restore_path, so create it on a successful restore too.
    os.makedirs(hp.checkpoint_path, exist_ok=True)

    # Init logger
    os.makedirs(hp.log_path, exist_ok=True)

    # NOTE(review): the vocoder is loaded but not used below — confirm whether
    # it is still needed.
    waveglow = None
    if hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()

    # Get dataset
    dataset = BufferDataset(buffer)

    # Get Training Loader
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_expand_size *
                                 hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn_tensor,
                                 drop_last=True,
                                 num_workers=0)
    total_step = hp.epochs * len(training_loader) * hp.batch_expand_size

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()

    for epoch in range(hp.epochs):
        for i, batchs in enumerate(training_loader):
            # real batch start here
            for j, db in enumerate(batchs):
                print(len(batchs), len(db))
                start_time = time.perf_counter()

                current_step = i * hp.batch_expand_size + j + args.restore_step + \
                    epoch * len(training_loader) * hp.batch_expand_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                duration = db["duration"].int().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=duration)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, duration)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = duration_loss.item()

                # One append-only text file per loss term, one value per step.
                for file_stem, value in (('total_loss', t_l),
                                         ('mel_loss', m_l),
                                         ('mel_postnet_loss', m_p_l),
                                         ('duration_loss', d_l)):
                    with open(os.path.join(hp.log_path, file_stem + ".txt"),
                              "a") as f_out:
                        f_out.write(str(value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp1.grad_clip_thresh)

                # Update weights
                if args.frozen_learning_rate:
                    scheduled_optim.step_and_update_lr_frozen(
                        args.learning_rate_frozen)
                else:
                    scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Current Learning Rate is {:.6f}.".format(
                        scheduled_optim.get_learning_rate())
                    str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * np.mean(Time))

                    print("\n" + str1)
                    print(str2)
                    print(str3)
                    print(str4)

                    with open(os.path.join(hp.log_path, "logger.txt"),
                              "a") as f_logger:
                        f_logger.write(str1 + "\n")
                        f_logger.write(str2 + "\n")
                        f_logger.write(str3 + "\n")
                        f_logger.write(str4 + "\n")
                        f_logger.write("\n")

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(hp.checkpoint_path,
                                     'checkpoint_%d.pth.tar' % current_step))
                    print("save model at step %d ..." % current_step)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing history to its mean so the buffer
                    # stays bounded.
                    Time = np.array([np.mean(Time)])
def main(args):
    """Train FastSpeech2 with gradient accumulation, logging and checkpoints.

    The function builds the dataset/loader, model, optimizer and loss,
    optionally restores a checkpoint, and then runs the training loop.
    Every ``hp.acc_steps`` micro-steps the accumulated gradients are clipped
    and applied; on those update steps it may also log, save a checkpoint,
    synthesize a sample from the current batch and run validation, depending
    on ``hp.log_step``, ``hp.save_step``, ``hp.synth_step`` and
    ``hp.eval_step``.

    Args:
        args: Parsed command-line namespace. Only ``args.restore_step`` is
            read: it selects the checkpoint to resume from and offsets the
            global step counter.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    # NOTE(review): each loader item is iterated as a list of batches below;
    # presumably collate_fn splits the hp.batch_size**2 samples into
    # hp.batch_size real batches -- confirm against Dataset.collate_fn.
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = FastSpeech2().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists.  Only a *missing* checkpoint file means
    # "start from scratch"; a corrupt or incompatible checkpoint now raises
    # instead of being silently ignored.
    checkpoint_path = os.path.join(hp.checkpoint_path)
    checkpoint_file = os.path.join(
        checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step))
    if os.path.isfile(checkpoint_file):
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    else:
        print("\n---Start New Training---\n")
    os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder (None when hp.vocoder names neither supported vocoder).
    vocoder = None
    vocoder_infer = None
    if hp.vocoder == 'melgan':
        vocoder = utils.get_melgan()
        vocoder_infer = utils.melgan_infer
    elif hp.vocoder == 'waveglow':
        vocoder = utils.get_waveglow()
        vocoder_infer = utils.waveglow_infer
    if vocoder is not None:
        vocoder.to(device)

    # Init logger
    log_path = hp.log_path
    train_log_dir = os.path.join(log_path, 'train')
    val_log_dir = os.path.join(log_path, 'validation')
    os.makedirs(train_log_dir, exist_ok=True)
    os.makedirs(val_log_dir, exist_ok=True)
    train_logger = SummaryWriter(train_log_dir)
    val_logger = SummaryWriter(val_log_dir)

    # Init synthesis directory
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Define Some Information
    Time = np.array([])  # durations of recent update steps, for the ETA
    Start = time.perf_counter()

    # Training
    model = model.train()
    steps_per_epoch = len(loader) * hp.batch_size
    total_step = hp.epochs * steps_per_epoch
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = (i * hp.batch_size + j + args.restore_step +
                                epoch * steps_per_epoch + 1)

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                (mel_output, mel_postnet_output, duration_output, src_mask,
                 pred_mel_mask, enc_attns, dec_attns, W) = model(
                     text, src_len, mel_len, max_src_len, max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss = Loss(
                    duration_output, mel_len, mel_output, mel_postnet_output,
                    mel_target, src_mask, pred_mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss

                # Logger (un-scaled values, taken before the accumulation
                # division below)
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()

                # Backward: scale so the accumulated gradient is an average
                # over hp.acc_steps micro-steps.
                (total_loss / hp.acc_steps).backward()
                if current_step % hp.acc_steps != 0:
                    # Keep accumulating; everything below runs only on
                    # update steps (this also skips the timing bookkeeping).
                    continue

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # Time can still be empty on the very first update step;
                    # report nan explicitly instead of triggering numpy's
                    # "mean of empty slice" warning.
                    mean_step_time = (np.mean(Time)
                                      if Time.size else float("nan"))

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f};".format(
                        t_l, m_l, m_p_l, d_l)
                    str3 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"),
                              "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    # NOTE(review): assumed to belong to the log_step branch
                    # (attention plots on every step would be costly) --
                    # confirm against the original layout.
                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    plot_attn(train_logger, enc_attns, dec_attns,
                              current_step, hp)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    # Synthesize the first utterance of the current batch,
                    # trimmed to its true mel length.
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_torch = mel_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))

                    if vocoder is not None:
                        # Same three files for either vocoder: raw decoder
                        # output, postnet output and ground truth.
                        vocoder_jobs = (
                            ("", mel_torch),
                            ("postnet_", mel_postnet_torch),
                            ("ground-truth_", mel_target_torch),
                        )
                        for tag, spec in vocoder_jobs:
                            vocoder_infer(
                                spec, vocoder,
                                os.path.join(
                                    hp.synth_path,
                                    'step_{}_{}{}.wav'.format(
                                        current_step, tag, hp.vocoder)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        val_d_l, val_m_l, val_m_p_l = evaluate(
                            model, current_step)
                        val_t_l = val_d_l + val_m_l + val_m_p_l

                        val_logger.add_scalar('Loss/total_loss', val_t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', val_m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss',
                                              val_m_p_l, current_step)
                        val_logger.add_scalar('Loss/duration_loss', val_d_l,
                                              current_step)
                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the history into its mean so the buffer stays
                    # bounded while the ETA keeps a running estimate.
                    Time = np.array([np.mean(Time)])
def main(args):
    """Synthesize audio for every utterance in ``test.txt`` and index it.

    For each utterance the ground-truth mel, the decoder output and the
    postnet output are converted to audio with Griffin-Lim and, when
    ``hp.vocoder`` is ``'melgan'`` or ``'waveglow'``, with that vocoder.
    Every generated file is recorded in ``index.tsv`` inside
    ``hp.test_path`` via ``_write_index_line``.

    Args:
        args: Parsed command-line namespace. Reads ``args.step`` (checkpoint
            step, also embedded in file names), ``args.model_fs`` (passed to
            ``get_FastSpeech2`` as ``full_path``) and, for melgan only,
            ``args.model_melgan``.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("test.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=0)

    model = get_FastSpeech2(args.step, full_path=args.model_fs).to(device)

    # Load vocoder (None when hp.vocoder names neither supported vocoder).
    vocoder = None
    vocoder_infer = None
    vocoder_label = None
    if hp.vocoder == 'melgan':
        vocoder = utils.get_melgan(full_path=args.model_melgan)
        vocoder_infer = utils.melgan_infer
        vocoder_label = "Melgan"
    elif hp.vocoder == 'waveglow':
        vocoder = utils.get_waveglow()
        vocoder_infer = utils.waveglow_infer
        vocoder_label = "Waveglow"
    if vocoder is not None:
        # The mel tensors handed to the vocoder live on `device`, so the
        # vocoder has to be there as well.
        vocoder.to(device)

    # Init logger
    log_path = hp.log_path
    test_log_dir = os.path.join(log_path, 'test')
    os.makedirs(test_log_dir, exist_ok=True)
    # Not written to below; kept because the original creates this writer.
    test_logger = SummaryWriter(test_log_dir)

    # Init synthesis directory
    test_path = hp.test_path
    os.makedirs(test_path, exist_ok=True)

    current_step = args.step
    prefix = ""

    # Testing
    print("Generate test audio")
    # The context manager guarantees index.tsv is flushed and closed even if
    # synthesis fails midway; no_grad avoids building autograd graphs during
    # pure inference.
    with open(os.path.join(test_path, "index.tsv"), "w") as findex, \
            torch.no_grad():
        for batchs in loader:
            for j, data_of_batch in enumerate(batchs):
                print("Start batch", j)
                fids = data_of_batch["id"]

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)

                (mel_output, mel_postnet_output, log_duration_output,
                 f0_output, energy_output, src_mask, mel_mask, _) = model(
                     text, src_len)

                for idx, n_frames in enumerate(mel_len):
                    fid = fids[idx]
                    print("Generate audio for:", fid)

                    # NOTE(review): predicted mels are trimmed with the
                    # ground-truth length, as in the original; the predicted
                    # length may differ -- confirm this is intended.
                    length = n_frames.item()
                    _mel_target_torch = mel_target[
                        idx, :length].detach().unsqueeze(0).transpose(1, 2)
                    _mel_target = mel_target[
                        idx, :length].detach().cpu().transpose(0, 1)
                    _mel_torch = mel_output[
                        idx, :length].detach().unsqueeze(0).transpose(1, 2)
                    _mel = mel_output[
                        idx, :length].detach().cpu().transpose(0, 1)
                    _mel_postnet_torch = mel_postnet_output[
                        idx, :length].detach().unsqueeze(0).transpose(1, 2)
                    _mel_postnet = mel_postnet_output[
                        idx, :length].detach().cpu().transpose(0, 1)

                    # Griffin-Lim renderings: (index label, index kind,
                    # file-name suffix, spectrogram).
                    griffin_lim_jobs = (
                        ("Griffin Lim", "vocoder", "gt_griffin_lim",
                         _mel_target),
                        ("FastSpeech2 + GL", "tts", "griffin_lim", _mel),
                        ("FastSpeech2 + PN + GL", "tts",
                         "postnet_griffin_lim", _mel_postnet),
                    )
                    for label, kind, suffix, spec in griffin_lim_jobs:
                        fname = "{}{}_step_{}_{}.wav".format(
                            prefix, fid, current_step, suffix)
                        Audio.tools.inv_mel_spec(
                            spec, os.path.join(hp.test_path, fname))
                        _write_index_line(findex, label, kind, fname, "", fid)

                    if vocoder is None:
                        continue

                    # Neural-vocoder renderings.  File names carry prefix and
                    # utterance id for BOTH vocoders, so waveglow output is
                    # no longer overwritten by each subsequent utterance and
                    # is listed in the index like melgan output.
                    vocoder_jobs = (
                        (vocoder_label, "vocoder", "ground-truth_",
                         _mel_target_torch),
                        ("FastSpeech2 + " + vocoder_label, "tts", "",
                         _mel_torch),
                        ("FastSpeech2 + PN + " + vocoder_label, "tts",
                         "postnet_", _mel_postnet_torch),
                    )
                    for label, kind, tag, spec in vocoder_jobs:
                        fname = '{}{}_step_{}_{}{}.wav'.format(
                            prefix, fid, current_step, tag, hp.vocoder)
                        vocoder_infer(spec, vocoder,
                                      os.path.join(hp.test_path, fname))
                        _write_index_line(findex, label, kind, fname, "", fid)