def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    """Synthesize one utterance and save its audio and a diagnostic plot.

    Files written under ``hp.test_path``:
      * ``{prefix}_griffin_lim_{sentence}.wav`` via Griffin-Lim,
      * ``{prefix}_{hp.vocoder}_{sentence}.wav`` via each vocoder given,
      * ``{prefix}_{sentence}.png`` with post-net mel, F0 and energy.

    Args:
        model: acoustic model called as ``model(text, src_len)``; it must
            return 8 values, of which the post-net mel (2nd), F0 (4th) and
            energy (5th) are used.
        waveglow: WaveGlow vocoder, or ``None`` to skip it.
        melgan: MelGAN vocoder, or ``None`` to skip it.
        text: batched token-id tensor; ``text.shape[1]`` is the source length.
        sentence: text embedded in output filenames (first 200 characters).
        prefix: string prepended to every output filename.
    """
    # Long filenames would raise OSError, so cap the sentence length.
    sentence = sentence[:200]

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)
    with torch.no_grad():
        _, mel_postnet, _, f0_output, energy_output, _, _, _ = model(text, src_len)

    # Batched tensor with its last two axes swapped, for the neural vocoders
    # (presumably (batch, time, n_mels) -> (batch, n_mels, time); TODO confirm).
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    # First batch item on CPU with the same axis swap, for Griffin-Lim/plots.
    mel_postnet_cpu = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(mel_postnet_cpu, os.path.join(
        hp.test_path, '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    # Both vocoders are given the same output path (named after hp.vocoder);
    # if both are supplied, the MelGAN call runs last on that path.
    vocoder_path = os.path.join(
        hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
    if waveglow is not None:
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_path)
    if melgan is not None:
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_path)

    utils.plot_data([(mel_postnet_cpu.numpy(), f0_output, energy_output)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path,
                                          '{}_{}.png'.format(prefix, sentence)))
def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    """Synthesize one utterance from a normalized model and save audio/plot.

    The model's post-net mel, F0 and energy outputs are de-normalized with
    statistics loaded from ``hp.preprocessed_path`` (``mel_stat.npy``,
    ``f0_stat.npy``, ``energy_stat.npy``) via ``utils.de_norm``.

    Files written under ``hp.test_path``:
      * ``{prefix}_griffin_lim_{sentence}.wav`` via Griffin-Lim,
      * ``{prefix}_{hp.vocoder}_{sentence}.wav`` via each vocoder given,
      * ``{prefix}_{sentence}.png`` with post-net mel, F0 and energy.

    Args:
        model: acoustic model called as ``model(text, src_len)``; it must
            return 8 values, of which the post-net mel (2nd), F0 (4th) and
            energy (5th) are used.
        waveglow: WaveGlow vocoder, or ``None`` to skip it.
        melgan: MelGAN vocoder, or ``None`` to skip it.
        text: batched token-id tensor; ``text.shape[1]`` is the source length.
        sentence: text embedded in output filenames (first 10 characters).
        prefix: string prepended to every output filename.
    """
    # Long filenames would raise OSError, so cap the sentence length.
    sentence = sentence[:10]

    def _load_stats(filename):
        # Unpacking the loaded tensor along its first axis yields (mean, std),
        # so each file presumably stores a 2-row array; TODO confirm.
        mean, std = torch.tensor(
            np.load(os.path.join(hp.preprocessed_path, filename)),
            dtype=torch.float).to(device)
        return mean.reshape(1, -1), std.reshape(1, -1)

    mean_mel, std_mel = _load_stats("mel_stat.npy")
    mean_f0, std_f0 = _load_stats("f0_stat.npy")
    mean_energy, std_energy = _load_stats("energy_stat.npy")

    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        _, mel_postnet, _, f0_output, energy_output, _, _, _ = model(text, src_len)

    # De-normalize in the model's output layout, then swap the last two axes
    # for the vocoders and Griffin-Lim.
    mel_postnet_torch = utils.de_norm(
        mel_postnet.detach(), mean_mel, std_mel).transpose(1, 2)
    f0_output = utils.de_norm(
        f0_output[0], mean_f0, std_f0).squeeze().detach().cpu().numpy()
    energy_output = utils.de_norm(
        energy_output[0], mean_energy, std_energy).squeeze().detach().cpu().numpy()

    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(
        mel_postnet_torch[0],
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    vocoder_path = os.path.join(
        hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
    if waveglow is not None:
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_path)
    if melgan is not None:
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_path)

    utils.plot_data(
        [(mel_postnet_torch[0].detach().cpu().numpy(), f0_output, energy_output)],
        ['Synthesized Spectrogram'],
        filename=os.path.join(hp.test_path, '{}_{}.png'.format(prefix, sentence)))
def synthesize(model, waveglow, py_text_seq, cn_text_seq, duration_control=1.0,
               prefix='', output_dir='/dev/shm/'):
    """Synthesize pinyin (+ optional hanzi) input with WaveGlow.

    Args:
        model: acoustic model called as
            ``model(py_text_seq, src_len, hz_seq=..., d_control=...)``; it
            must return 6 values, of which the post-net mel (2nd) is used.
        waveglow: WaveGlow vocoder.
        py_text_seq: batched pinyin token-id tensor; ``shape[1]`` is the
            source length.
        cn_text_seq: hanzi token-id tensor passed as ``hz_seq`` (may be
            whatever the model accepts there, e.g. ``None`` — TODO confirm).
        duration_control: duration scaling factor forwarded as ``d_control``.
        prefix: output filename prefix.
        output_dir: directory for the wav file. Defaults to ``/dev/shm/``
            (an in-memory filesystem on Linux).

    Returns:
        Path of the written ``{prefix}-out.wav`` file.
    """
    src_len = torch.from_numpy(np.array([py_text_seq.shape[1]])).to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        _, mel_postnet, _, _, _, _ = model(
            py_text_seq, src_len, hz_seq=cn_text_seq, d_control=duration_control)
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()

    dst_name = os.path.join(output_dir, '{}-out.wav'.format(prefix))
    # hp.mel_mean is added back before vocoding; presumably the model
    # predicts mean-subtracted mels — TODO confirm against training.
    utils.waveglow_infer(mel_postnet_torch + hp.mel_mean, waveglow, dst_name)
    return dst_name
def synthesize(model, text, sentence, prefix=''):
    """Synthesize one utterance, loading the vocoder selected by ``hp.vocoder``.

    The model is moved to ``device`` only for the forward pass and then back
    to the CPU.

    Files written under ``hp.test_path``:
      * ``{prefix}_griffin_lim_{sentence}.wav`` via Griffin-Lim,
      * ``{prefix}_{hp.vocoder}_{sentence}.wav`` when ``hp.vocoder`` is
        ``'melgan'`` or ``'waveglow'``,
      * ``{prefix}_{sentence}.png`` with post-net mel, F0 and energy.

    Args:
        model: acoustic model called as ``model(text, src_pos)``; it must
            return 5 values (mel, post-net mel, duration, F0, energy).
        text: batched token-id tensor; ``text.shape[1]`` is the source length.
        sentence: text embedded in output filenames.
        prefix: string prepended to every output filename.
    """
    # 1-based source positions with shape (1, text.shape[1]).
    src_pos = torch.arange(1, text.shape[1] + 1, dtype=torch.long).unsqueeze(0).to(device)

    model.to(device)
    try:
        # Inference only: no_grad avoids building an autograd graph.
        with torch.no_grad():
            _, mel_postnet, _, f0_output, energy_output = model(text, src_pos)
    finally:
        # Move the model back to the CPU even if the forward pass fails.
        model.to('cpu')

    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel_postnet_cpu = mel_postnet[0].cpu().transpose(0, 1).detach()
    f0_output = f0_output[0].detach().cpu().numpy()
    energy_output = energy_output[0].detach().cpu().numpy()

    os.makedirs(hp.test_path, exist_ok=True)

    Audio.tools.inv_mel_spec(
        mel_postnet_cpu,
        os.path.join(hp.test_path,
                     '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

    vocoder_path = os.path.join(
        hp.test_path, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
    # NOTE(review): the vocoder is reloaded on every call, which is costly
    # when synthesizing many sentences; consider loading it once upstream.
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_path)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_path)

    utils.plot_data([(mel_postnet_cpu.numpy(), f0_output, energy_output)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(hp.test_path,
                                          '{}_{}.png'.format(prefix, sentence)))
def synthesize(model, waveglow, melgan, text, sentence, prefix='',
               duration_control=1.0, pitch_control=1.0, energy_control=1.0,
               output_dir=None):
    """Synthesize one utterance with duration/pitch/energy control.

    Files written under ``output_dir`` (default ``hp.test_path``):
      * ``{prefix}_griffin_lim.wav`` via Griffin-Lim,
      * ``{prefix}_{hp.vocoder}.wav`` via each vocoder given.

    Args:
        model: acoustic model called with ``d_control``/``p_control``/
            ``e_control``; it must return 8 values, of which the post-net
            mel (2nd) is used.
        waveglow: WaveGlow vocoder, or ``None`` to skip it.
        melgan: MelGAN vocoder, or ``None`` to skip it.
        text: batched token-id tensor; ``text.shape[1]`` is the source length.
        sentence: unused; kept for call compatibility.
        prefix: string prepended to every output filename.
        duration_control, pitch_control, energy_control: factors forwarded to
            the model.
        output_dir: output directory; falsy means ``hp.test_path``.
    """
    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        _, mel_postnet, _, _, _, _, _, _ = model(
            text, src_len, d_control=duration_control,
            p_control=pitch_control, e_control=energy_control)

    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    mel_postnet_cpu = mel_postnet[0].cpu().transpose(0, 1).detach()

    if not output_dir:
        output_dir = hp.test_path
    os.makedirs(output_dir, exist_ok=True)

    # All outputs, including Griffin-Lim, go to output_dir (the directory
    # created above).
    gl_fname = '{}_griffin_lim.wav'.format(prefix)
    Audio.tools.inv_mel_spec(mel_postnet_cpu, os.path.join(output_dir, gl_fname))

    vocoder_path = os.path.join(output_dir, '{}_{}.wav'.format(prefix, hp.vocoder))
    if waveglow is not None:
        utils.waveglow_infer(mel_postnet_torch, waveglow, vocoder_path)
    if melgan is not None:
        utils.melgan_infer(mel_postnet_torch, melgan, vocoder_path)
def synthesize(model, waveglow, text, idx, prefix='', duration_control=1.0,
               pitch_control=1.0, energy_control=1.0):
    """Synthesize one utterance with WaveGlow and print per-stage timings.

    The WaveGlow output is written to
    ``{args.test_path}/{prefix}_{hp.vocoder}_{idx}.wav``. If ``waveglow`` is
    ``None``, only the acoustic model runs. Then it prints the acoustic-model
    ("FS") time and the vocoder time in seconds.

    Args:
        model: acoustic model called with ``d_control``/``p_control``/
            ``e_control``; it must return 8 values, of which the post-net
            mel (2nd) is used.
        waveglow: WaveGlow vocoder, or ``None``.
        text: batched token-id tensor; ``text.shape[1]`` is the source length.
        idx: identifier used in the output filename and the timing line.
        prefix: output filename prefix.
        duration_control, pitch_control, energy_control: factors forwarded to
            the model.
    """
    # Create the output directory before starting the clock so filesystem
    # work is not counted as model time.
    os.makedirs(args.test_path, exist_ok=True)

    t = time()
    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)
    # Inference only: no_grad avoids building an autograd graph.
    with torch.no_grad():
        _, mel_postnet, _, _, _, _, _, _ = model(
            text, src_len, d_control=duration_control,
            p_control=pitch_control, e_control=energy_control)
    mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
    if mel_postnet_torch.is_cuda:
        # CUDA kernels run asynchronously; wait for them so t1 reflects the
        # acoustic model's real runtime rather than just kernel launches.
        torch.cuda.synchronize(mel_postnet_torch.device)
    t1 = time() - t

    if waveglow is not None:
        utils.waveglow_infer(
            mel_postnet_torch, waveglow,
            os.path.join(args.test_path,
                         '{}_{}_{}.wav'.format(prefix, hp.vocoder, idx)))
    t2 = time() - t
    print('{}: time FS: {} (s) time {}: {}'.format(idx, t1, hp.vocoder, t2 - t1))
def evaluate(model, step, vocoder=None):
    """Compute mean validation losses over ``val.txt`` and log them.

    The losses are printed and appended to ``hp.log_path/eval.txt``.

    Args:
        model: acoustic model. When ``hp.vocoder == 'WORLD'`` it is called
            without ``mel_refer`` and must return 11 values. Otherwise it is
            called with ``mel_refer`` and must return 8 values, the last being
            the predicted mel lengths.
        step: training step, used only in the log header.
        vocoder: optional MelGAN/WaveGlow vocoder matching ``hp.vocoder``.
            When given in mel mode, these are written to ``hp.eval_path`` for
            each validation utterance: ground-truth and predicted audio, the
            predicted mel (``.npy``) and a comparison plot. Ignored in WORLD
            mode.

    Returns:
        In WORLD mode ``(d_l, f_l, e_l, ap_l, sp_l, sp_p_l)``, otherwise
        ``(d_l, f_l, e_l, mel_l, mel_p_l)``. Each value is the mean of the
        per-batch losses (``nan`` if no batches were evaluated).
    """
    torch.manual_seed(0)
    world = hp.vocoder == 'WORLD'

    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(dataset, batch_size=hp.batch_size * 4, shuffle=False,
                        collate_fn=dataset.collate_fn, drop_last=False,
                        num_workers=0)
    Loss = FastSpeech2Loss().to(device)

    # Per-batch loss values, averaged after the loop.
    d_l, f_l, e_l = [], [], []
    ap_l, sp_l, sp_p_l = [], [], []
    mel_l, mel_p_l = [], []

    write_samples = vocoder is not None and not world
    if write_samples:
        os.makedirs(hp.eval_path, exist_ok=True)

    for batchs in loader:
        # Each loader item is iterated as a list of sub-batches.
        for data_of_batch in batchs:
            id_ = data_of_batch["id"]
            condition = torch.from_numpy(data_of_batch["condition"]).long().to(device)
            mel_refer = torch.from_numpy(data_of_batch["mel_refer"]).float().to(device)
            if world:
                ap_target = torch.from_numpy(data_of_batch["ap_target"]).float().to(device)
                sp_target = torch.from_numpy(data_of_batch["sp_target"]).float().to(device)
            else:
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).long().to(device)
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                if world:
                    (ap_output, sp_output, sp_postnet_output, log_duration_output,
                     f0_output, energy_output, src_mask, ap_mask, sp_mask,
                     _, _) = model(condition, src_len, mel_len, D, f0, energy,
                                   max_src_len, max_mel_len)
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0, energy_output, energy,
                        ap_output=ap_output, sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target, sp_target=sp_target,
                        src_mask=src_mask, ap_mask=ap_mask, sp_mask=sp_mask)
                else:
                    (mel_output, mel_postnet_output, log_duration_output,
                     f0_output, energy_output, src_mask, mel_mask,
                     out_mel_len) = model(condition, mel_refer, src_len, mel_len,
                                          D, f0, energy, max_src_len, max_mel_len)
                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0,
                        energy_output, energy, mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target, src_mask=~src_mask,
                        mel_mask=~mel_mask)

            d_l.append(d_loss.item())
            f_l.append(f_loss.item())
            e_l.append(e_loss.item())
            if world:
                ap_l.append(ap_loss.item())
                sp_l.append(sp_loss.item())
                sp_p_l.append(sp_postnet_loss.item())
            else:
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

            if not write_samples:
                continue

            # Vocode and plot each utterance only when a vocoder is provided.
            for k in range(len(mel_target)):
                basename = id_[k]
                gt_length = int(mel_len[k])
                out_length = int(out_mel_len[k])

                mel_target_torch = mel_target[k:k + 1, :gt_length].transpose(1, 2).detach()
                mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()
                mel_postnet_torch = mel_postnet_output[k:k + 1, :out_length].transpose(1, 2).detach()
                mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                gt_wav = os.path.join(
                    hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder))
                eval_wav = os.path.join(
                    hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder))
                if hp.vocoder == 'melgan':
                    utils.melgan_infer(mel_target_torch, vocoder, gt_wav)
                    utils.melgan_infer(mel_postnet_torch, vocoder, eval_wav)
                elif hp.vocoder == 'waveglow':
                    utils.waveglow_infer(mel_target_torch, vocoder, gt_wav)
                    utils.waveglow_infer(mel_postnet_torch, vocoder, eval_wav)
                np.save(os.path.join(hp.eval_path, 'eval_{}_mel.npy'.format(basename)),
                        mel_postnet.numpy())

                f0_ = f0[k, :gt_length].detach().cpu().numpy()
                energy_ = energy[k, :gt_length].detach().cpu().numpy()
                f0_output_ = f0_output[k, :out_length].detach().cpu().numpy()
                energy_output_ = energy_output[k, :out_length].detach().cpu().numpy()
                utils.plot_data(
                    [(mel_postnet.numpy(), f0_output_, energy_output_),
                     (mel_target_.numpy(), f0_, energy_)],
                    ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                    filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))

    def _mean(values):
        # Empty validation set -> nan instead of ZeroDivisionError.
        return sum(values) / len(values) if values else float('nan')

    d_l, f_l, e_l = _mean(d_l), _mean(f_l), _mean(e_l)
    if world:
        ap_l, sp_l, sp_p_l = _mean(ap_l), _mean(sp_l), _mean(sp_p_l)
        extra = [("AP Loss", ap_l), ("SP Loss", sp_l), ("SP Postnet Loss", sp_p_l)]
        result = (d_l, f_l, e_l, ap_l, sp_l, sp_p_l)
    else:
        mel_l, mel_p_l = _mean(mel_l), _mean(mel_p_l)
        extra = [("Mel Loss", mel_l), ("Mel Postnet Loss", mel_p_l)]
        result = (d_l, f_l, e_l, mel_l, mel_p_l)

    lines = ["FastSpeech2 Step {},".format(step),
             "Duration Loss: {}".format(d_l),
             "F0 Loss: {}".format(f_l),
             "Energy Loss: {}".format(e_l)]
    lines += ["{}: {}".format(name, value) for name, value in extra]

    print("\n" + lines[0])
    for line in lines[1:]:
        print(line)
    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        for line in lines:
            f_log.write(line + "\n")
        f_log.write("\n")
    return result
def evaluate(model, step, vocoder=None):
    """Compute mean validation losses on the pinyin(/hanzi) set and log them.

    Uses ``val_pinyin.txt`` and, when ``hp.with_hanzi`` is set, also
    ``val_hanzi.txt``. The losses are printed and appended to
    ``hp.log_path/eval.txt``.

    Args:
        model: acoustic model called with keyword arguments (``src_seq``,
            ``src_len``, ``hz_seq``, ``mel_len``, ``d_target``,
            ``max_src_len``, ``max_mel_len``); it must return 6 values, the
            last being the predicted mel lengths.
        step: training step, used only in the log header.
        vocoder: optional MelGAN/WaveGlow vocoder matching ``hp.vocoder``.
            When given, ground-truth and predicted audio and a comparison plot
            are written to ``hp.eval_path`` for each validation utterance.

    Returns:
        ``(d_l, mel_l, mel_p_l)``: mean duration, mel and post-net mel losses.
    """
    print('evaluating..')
    if hp.with_hanzi:
        dataset = Dataset(filename_py="val_pinyin.txt",
                          vocab_file_py='vocab_pinyin.txt',
                          filename_hz="val_hanzi.txt",
                          vocab_file_hz='vocab_hanzi.txt')
    else:
        dataset = Dataset(filename_py="val_pinyin.txt",
                          vocab_file_py='vocab_pinyin.txt',
                          filename_hz=None, vocab_file_hz=None)
    loader = DataLoader(dataset, batch_size=hp.batch_size ** 2, shuffle=False,
                        collate_fn=dataset.collate_fn, drop_last=False,
                        num_workers=0)
    Loss = FastSpeech2Loss().to(device)

    # Per-batch loss values, averaged after the loop.
    d_l, mel_l, mel_p_l = [], [], []

    if vocoder is not None:
        os.makedirs(hp.eval_path, exist_ok=True)

    bar = tqdm.tqdm_notebook(total=len(dataset) // hp.batch_size)
    try:
        for batchs in loader:
            # Each loader item is iterated as a list of sub-batches.
            for data_of_batch in batchs:
                bar.update(1)
                id_ = data_of_batch["id"]
                text = torch.from_numpy(data_of_batch["text"]).long().to(device)
                if hp.with_hanzi:
                    hz_text = torch.from_numpy(data_of_batch["hz_text"]).long().to(device)
                else:
                    hz_text = None
                mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                # Log-durations are real-valued; an integer cast would
                # truncate them before the duration loss.
                log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
                src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                with torch.no_grad():
                    (mel_output, mel_postnet_output, log_duration_output,
                     src_mask, mel_mask, out_mel_len) = model(
                        src_seq=text, src_len=src_len, hz_seq=hz_text,
                        mel_len=mel_len, d_target=D,
                        max_src_len=max_src_len, max_mel_len=max_mel_len)
                    # The loss target is the ground-truth mel minus
                    # hp.mel_mean. The mean is added back to predictions
                    # before vocoding/plotting below.
                    mel_loss, mel_postnet_loss, d_loss = Loss(
                        log_duration_output, log_D, mel_output,
                        mel_postnet_output, mel_target - hp.mel_mean,
                        ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is None:
                    continue

                # Vocode and plot each utterance only when a vocoder is provided.
                for k in range(len(mel_target)):
                    basename = id_[k]
                    gt_length = mel_len[k]
                    out_length = out_mel_len[k]

                    mel_target_torch = mel_target[k:k + 1, :gt_length].transpose(1, 2).detach()
                    mel_target_ = mel_target[k, :gt_length].cpu().transpose(0, 1).detach()
                    mel_postnet_torch = mel_postnet_output[k:k + 1, :out_length].transpose(1, 2).detach()
                    mel_postnet = mel_postnet_output[k, :out_length].cpu().transpose(0, 1).detach()

                    gt_wav = os.path.join(
                        hp.eval_path, 'ground-truth_{}_{}.wav'.format(basename, hp.vocoder))
                    eval_wav = os.path.join(
                        hp.eval_path, 'eval_{}_{}.wav'.format(basename, hp.vocoder))
                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(mel_target_torch, vocoder, gt_wav)
                        utils.melgan_infer(mel_postnet_torch + hp.mel_mean, vocoder, eval_wav)
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(mel_target_torch, vocoder, gt_wav)
                        utils.waveglow_infer(mel_postnet_torch + hp.mel_mean, vocoder, eval_wav)

                    utils.plot_data(
                        [mel_postnet.numpy() + hp.mel_mean, mel_target_.numpy()],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(hp.eval_path, 'eval_{}.png'.format(basename)))
    finally:
        bar.close()

    d_l = sum(d_l) / len(d_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    lines = ["FastSpeech2 Step {},".format(step),
             "Duration Loss: {}".format(d_l),
             "Mel Loss: {}".format(mel_l),
             "Mel Postnet Loss: {}".format(mel_p_l),
             "total Loss: {}".format(mel_p_l + mel_l + d_l)]

    print("\n" + lines[0])
    for line in lines[1:]:
        print(line)
    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        for line in lines:
            f_log.write(line + "\n")
        f_log.write("\n")
    return d_l, mel_l, mel_p_l
def main(args):
    """Train FastSpeech2, optionally resuming from ``args.restore_step``.

    Each sub-batch gets a forward/backward pass. Every ``hp.acc_steps``
    sub-steps, gradients are clipped and the optimizer steps. On those update
    steps, at the configured intervals, the loop also:
      * logs losses to stdout and TensorBoard,
      * saves a checkpoint to ``hp.checkpoint_path``,
      * writes sample audio/plots to ``hp.synth_path``.

    With ``hp.vocoder == 'WORLD'`` the model predicts ``ap``/``sp`` features
    (presumably WORLD aperiodicity / spectral envelope). Otherwise it
    predicts mel spectrograms.

    Args:
        args: parsed CLI arguments; only ``args.restore_step`` is read.
    """
    torch.manual_seed(0)

    # Training is hard-wired to CUDA.
    device = 'cuda'

    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset, batch_size=hp.batch_size * 4, shuffle=True,
                         collate_fn=dataset.collate_fn, drop_last=True,
                         num_workers=8)

    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas,
                                 eps=hp.eps, weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Restore the checkpoint for restore_step if it exists. Only a missing
    # file means "start fresh". A corrupt or mismatched checkpoint now raises
    # instead of silently discarding the requested restore.
    checkpoint_path = hp.checkpoint_path
    try:
        checkpoint = torch.load(os.path.join(
            checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step)))
    except FileNotFoundError:
        print("\n---Start New Training---\n")
    else:
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    os.makedirs(checkpoint_path, exist_ok=True)

    # Vocoder used to render synthesized samples.
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # TensorBoard loggers, written under hp.log_path (the directories created here).
    log_path = hp.log_path
    os.makedirs(os.path.join(log_path, 'train'), exist_ok=True)
    os.makedirs(os.path.join(log_path, 'validation'), exist_ok=True)
    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir=os.path.join(log_path, 'train', current_time))
    # Not written to while periodic evaluation is disabled.
    val_logger = SummaryWriter(log_dir=os.path.join(log_path, 'validation', current_time))

    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Timing state for the progress/ETA line.
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0
    total_step = hp.epochs * len(loader) * hp.batch_size

    model = model.train()
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            # Each loader item is iterated as a list of sub-batches.
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()
                # Global sub-step index, continuing from restore_step.
                current_step = (i * len(batchs) + j + args.restore_step
                                + epoch * len(loader) * len(batchs) + 1)

                condition = torch.from_numpy(data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward pass and losses.
                if hp.vocoder == 'WORLD':
                    (ap_output, sp_output, sp_postnet_output, log_duration_output,
                     f0_output, energy_output, src_mask, ap_mask, sp_mask,
                     _, _) = model(condition, src_len, mel_len, D, f0, energy,
                                   max_src_len, max_mel_len)
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0, energy_output, energy,
                        ap_output=ap_output, sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target, sp_target=sp_target,
                        src_mask=src_mask, ap_mask=ap_mask, sp_mask=sp_mask)
                    total_loss = (ap_loss + sp_loss + sp_postnet_loss
                                  + d_loss + f_loss + e_loss)
                else:
                    (mel_output, mel_postnet_output, log_duration_output,
                     f0_output, energy_output, src_mask, mel_mask, _) = model(
                        condition, mel_refer, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)
                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0,
                        energy_output, energy, mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target, src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = (mel_loss + mel_postnet_loss
                                  + d_loss + f_loss + e_loss)

                # Scalar copies for logging.
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                # Backward with gradient accumulation over hp.acc_steps sub-steps.
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clip gradients to avoid explosion, then update weights.
                nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = "Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = "Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},Dur:{:.4f},F0:{:.4f},Energy:{:.4f};".format(
                            t_l, m_l, m_p_l, d_l, f_l, e_l)
                    # np.mean of an empty array warns and returns nan.
                    eta = ((total_step - current_step) * np.mean(Time)
                           if len(Time) else float('nan'))
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0), eta)
                    print(str1 + str2 + str3)

                    train_logger.add_scalar('Loss/total_loss', t_l, current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l, current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l, current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l, current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l, current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l, current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l, current_step)

                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        os.path.join(checkpoint_path,
                                     'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0 or current_step == 5:
                    # Render the first item of the sub-batch.
                    length = mel_len[0].item()
                    f0_np = f0[0, :length].detach().cpu().numpy()
                    energy_np = energy[0, :length].detach().cpu().numpy()
                    f0_output_np = f0_output[0, :length].detach().cpu().numpy()
                    energy_output_np = energy_output[0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_target_torch = sp_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_torch = mel_output[0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet_torch = mel_postnet_output[0, :length].detach().unsqueeze(0).transpose(1, 2)

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(mel_torch, melgan, os.path.join(
                            hp.synth_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
                        utils.melgan_infer(mel_postnet_torch, melgan, os.path.join(
                            hp.synth_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)))
                        utils.melgan_infer(mel_target_torch, melgan, os.path.join(
                            hp.synth_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(mel_torch, waveglow, os.path.join(
                            hp.synth_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder)))
                        utils.waveglow_infer(mel_postnet_torch, waveglow, os.path.join(
                            hp.synth_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)))
                        utils.waveglow_infer(mel_target_torch, waveglow, os.path.join(
                            hp.synth_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(), 0, 1),
                            f0_output_np)
                        sf.write(os.path.join(
                            hp.synth_path, 'step_{}_postnet_{}.wav'.format(current_step, hp.vocoder)),
                            wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(), 0, 1),
                            f0_np)
                        sf.write(os.path.join(
                            hp.synth_path, 'step_{}_ground-truth_{}.wav'.format(current_step, hp.vocoder)),
                            wav, 32000)

                        utils.plot_data(
                            [(sp_postnet_torch[0].cpu().numpy(), f0_output_np, energy_output_np),
                             (sp_target_torch[0].cpu().numpy(), f0_np, energy_np)],
                            ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                            filename=os.path.join(synth_path, 'step_{}.png'.format(current_step)))

                        # Each matshow call opens a new figure. Close it after
                        # saving so figures do not pile up over training.
                        for name, matrix in (('sp_postnet', sp_postnet_torch[0]),
                                             ('ap', ap_torch[0])):
                            plt.matshow(matrix.cpu().numpy())
                            plt.savefig(os.path.join(
                                synth_path, '{}_{}.png'.format(name, current_step)))
                            plt.close()

                        with open(os.path.join(synth_path, 'D_{}.txt'.format(current_step)),
                                  'w') as fout:
                            fout.write(str(log_duration_output[0].detach().cpu().numpy()) + '\n')
                            fout.write(str(D[0].detach().cpu().numpy()) + '\n')
                            fout.write(str(condition[0, :, 2].detach().cpu().numpy()) + '\n')

                # Periodic evaluation via evaluate() is not run in this loop.

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing window to its mean to bound its size.
                    Time = np.array([np.mean(Time)])

    train_logger.close()
    val_logger.close()
def evaluate(model, step, vocoder=None):
    """Compute averaged FastSpeech2 validation losses on ``val.txt``.

    For every batch yielded by the loader only the first sub-batch is evaluated
    (the inner loop breaks at ``j == 1``). The per-sub-batch losses are
    averaged, printed, and appended to ``<hp.log_path>/eval.txt``.

    Args:
        model: FastSpeech2 model, called as ``model(text, src_len, mel_len, D,
            f0, energy, max_src_len, max_mel_len)``.
        step: Training step; only used in the printed / logged header.
        vocoder: Optional vocoder. When given, every sample of the evaluated
            sub-batches is vocoded with ``utils.melgan_infer`` or
            ``utils.waveglow_infer`` (selected by ``hp.vocoder``), its predicted
            postnet mel is saved as ``.npy``, and spectrogram / duration plots
            are written, all into ``hp.eval_path``.

    Returns:
        Tuple ``(duration_loss, f0_loss, energy_loss, mel_loss,
        mel_postnet_loss)`` of averaged floats.
    """
    torch.manual_seed(0)
    os.makedirs(hp.eval_path, exist_ok=True)

    # Get dataset
    dataset = Dataset("val.txt", sort=False)
    loader = DataLoader(
        dataset,
        batch_size=hp.batch_size ** 2,
        shuffle=False,
        collate_fn=dataset.collate_fn,
        drop_last=False,
        num_workers=0,
    )

    # Get loss function
    Loss = FastSpeech2Loss().to(device)

    # Per-sub-batch losses, averaged at the end
    d_l = []
    f_l = []
    e_l = []
    mel_l = []
    mel_p_l = []
    for batchs in loader:
        for j, data_of_batch in enumerate(batchs):
            # Only the first sub-batch of each loader batch is evaluated.
            if j == 1:
                break

            # Get Data
            id_ = data_of_batch["id"]
            text = torch.from_numpy(data_of_batch["text"]).long().to(device)
            mel_target = torch.from_numpy(
                data_of_batch["mel_target"]).float().to(device)
            D = torch.from_numpy(data_of_batch["D"]).int().to(device)
            # Log-durations are fractional regression targets and must stay
            # float; the former `.int()` cast truncated them and distorted the
            # validation duration loss.
            log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device)
            f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
            energy = torch.from_numpy(
                data_of_batch["energy"]).float().to(device)
            src_len = torch.from_numpy(
                data_of_batch["src_len"]).long().to(device)
            mel_len = torch.from_numpy(
                data_of_batch["mel_len"]).long().to(device)
            max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
            max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

            with torch.no_grad():
                # Forward
                (mel_output, mel_postnet_output, log_duration_output, f0_output,
                 energy_output, src_mask, mel_mask, out_mel_len) = model(
                    text, src_len, mel_len, D, f0, energy,
                    max_src_len, max_mel_len)

                # Cal Loss (the model's masks are passed inverted)
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0,
                    energy_output, energy, mel_output, mel_postnet_output,
                    mel_target, ~src_mask, ~mel_mask)

                d_l.append(d_loss.item())
                f_l.append(f_loss.item())
                e_l.append(e_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding and plotting spectrogram only when the
                    # vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_phone_length = src_len[k]
                        gt_length = mel_len[k]
                        out_length = out_mel_len[k]

                        mel_target_torch = mel_target[
                            k:k + 1, :gt_length].transpose(1, 2).detach()
                        mel_target_ = mel_target[
                            k, :gt_length].cpu().transpose(0, 1).detach()
                        mel_postnet_torch = mel_postnet_output[
                            k:k + 1, :out_length].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[
                            k, :out_length].cpu().transpose(0, 1).detach()

                        if hp.vocoder == 'melgan':
                            utils.melgan_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    '{}_ground-truth_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.melgan_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    '{}_eval_{}.wav'.format(
                                        basename, hp.vocoder)))
                        elif hp.vocoder == 'waveglow':
                            utils.waveglow_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'ground-truth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.waveglow_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'eval_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                        np.save(
                            os.path.join(hp.eval_path,
                                         '{}_eval_mel.npy'.format(basename)),
                            mel_postnet.numpy())

                        f0_ = f0[k, :gt_length].detach().cpu().numpy()
                        energy_ = energy[k, :gt_length].detach().cpu().numpy()
                        f0_output_ = f0_output[
                            k, :out_length].detach().cpu().numpy()
                        energy_output_ = energy_output[
                            k, :out_length].detach().cpu().numpy()
                        d_ = D[k, :gt_phone_length].detach().cpu().numpy()
                        log_d_output_ = log_duration_output[
                            k, :gt_phone_length].detach().cpu().numpy()
                        # Undo the log(d + log_offset) transform for plotting.
                        d_output_ = np.exp(log_d_output_) - hp.log_offset

                        utils.plot_data(
                            [(mel_postnet.numpy(), f0_output_, energy_output_),
                             (mel_target_.numpy(), f0_, energy_)],
                            ['Synthesized Spectrogram',
                             'Ground-Truth Spectrogram'],
                            filename=os.path.join(
                                hp.eval_path, '{}_eval.png'.format(basename)),
                        )
                        utils.plot_duration(
                            [d_output_, d_],
                            ["Synthesized Duration", "Ground-Truth"],
                            filename=os.path.join(
                                hp.eval_path,
                                '{}_eval_dur.png'.format(basename)),
                        )

    d_l = sum(d_l) / len(d_l)
    f_l = sum(f_l) / len(f_l)
    e_l = sum(e_l) / len(e_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Duration Loss: {}".format(d_l)
    str3 = "F0 Loss: {}".format(f_l)
    str4 = "Energy Loss: {}".format(e_l)
    str5 = "Mel Loss: {}".format(mel_l)
    str6 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)
    print(str6)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        for line in (str1, str2, str3, str4, str5, str6):
            f_log.write(line + "\n")
        f_log.write("\n")

    return d_l, f_l, e_l, mel_l, mel_p_l
def main(args):
    """Train FastSpeech2 on ``train.txt`` with logging, checkpointing and sampling.

    The checkpoint ``checkpoint_<args.restore_step>.pth.tar`` in
    ``hp.checkpoint_path`` is restored when it can be loaded; otherwise training
    starts from scratch. Per-step losses are appended to text files in
    ``hp.log_path``. The following actions run periodically:

    * every ``hp.log_step`` steps, losses are printed and sent to TensorBoard;
    * every ``hp.save_step`` steps, a checkpoint is saved;
    * every ``hp.synth_step`` steps, Griffin-Lim and vocoder audio plus a plot
      for the first utterance of the current sub-batch are written to
      ``hp.synth_path``;
    * every ``hp.eval_step`` steps, ``evaluate`` is run.

    Args:
        args: Parsed command-line arguments; only ``args.restore_step`` is read.
    """
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset, batch_size=hp.batch_size ** 2, shuffle=True,
                        collate_fn=dataset.collate_fn, drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(), betas=hp.betas,
                                 eps=hp.eps, weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(os.path.join(
            checkpoint_path, 'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except Exception as err:
        # Best-effort restore: fall back to fresh training, but (unlike a bare
        # `except:`) let KeyboardInterrupt/SystemExit through and say why an
        # explicitly requested restore failed.
        if args.restore_step:
            print("Could not restore checkpoint at step {}: {}".format(
                args.restore_step, err))
        print("\n---Start New Training---\n")
        os.makedirs(checkpoint_path, exist_ok=True)

    # Load vocoder
    melgan = None
    waveglow = None
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    os.makedirs(log_path, exist_ok=True)
    logger = SummaryWriter(log_path)

    # Init synthesis directory (all training samples, including vocoder
    # output, are written here)
    synth_path = hp.synth_path
    os.makedirs(synth_path, exist_ok=True)

    # Define Some Information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    # Loop-invariant; assumes every loader item holds hp.batch_size
    # sub-batches (consistent with batch_size ** 2 and the step formula below).
    total_step = hp.epochs * len(loader) * hp.batch_size
    for epoch in range(hp.epochs):
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = (i * hp.batch_size + j + args.restore_step
                                + epoch * len(loader) * hp.batch_size + 1)

                # Init
                scheduled_optim.zero_grad()

                # Get Data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                # int32 (not int16) so long utterances cannot overflow.
                max_len = max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                (mel_output, mel_postnet_output, duration_output, f0_output,
                 energy_output) = model(
                    text, src_pos, mel_pos, max_len, D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target,
                    src_len, mel_len)
                total_loss = (mel_loss + mel_postnet_loss + d_loss
                              + f_loss + e_loss)

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                step_losses = (
                    ("total_loss.txt", t_l),
                    ("mel_loss.txt", m_l),
                    ("mel_postnet_loss.txt", m_p_l),
                    ("duration_loss.txt", d_l),
                    ("f0_loss.txt", f_l),
                    ("energy_loss.txt", e_l),
                )
                for file_name, value in step_losses:
                    with open(os.path.join(log_path, file_name), "a") as f_loss:
                        f_loss.write(str(value) + "\n")

                # Backward
                total_loss.backward()

                # Clipping gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(
                    model.parameters(), hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    # Guard against an empty timing window (np.mean -> nan).
                    mean_step_time = np.mean(Time) if len(Time) > 0 else 0.0

                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = ("Total Loss: {:.4f}, Mel Loss: {:.4f}, "
                            "Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, "
                            "F0 Loss: {:.4f}, Energy Loss: {:.4f};").format(
                        t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = ("Time Used: {:.3f}s, "
                            "Estimated Time Remaining: {:.3f}s.").format(
                        (Now - Start),
                        (total_step - current_step) * mean_step_time)

                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    training_scalars = (
                        ('Loss/total_loss', t_l),
                        ('Loss/mel_loss', m_l),
                        ('Loss/mel_postnet_loss', m_p_l),
                        ('Loss/duration_loss', d_l),
                        ('Loss/F0_loss', f_l),
                        ('Loss/energy_loss', e_l),
                    )
                    for tag, value in training_scalars:
                        logger.add_scalars(
                            tag, {'training': value}, current_step)

                if current_step % hp.save_step == 0:
                    torch.save(
                        {'model': model.state_dict(),
                         'optimizer': optimizer.state_dict()},
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[0, :length].detach(
                    ).unsqueeze(0).transpose(1, 2)
                    mel_target_cpu = mel_target[0, :length].detach(
                    ).cpu().transpose(0, 1)
                    mel_torch = mel_output[0, :length].detach(
                    ).unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[0, :length].detach(
                    ).unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[0, :length].detach(
                    ).cpu().transpose(0, 1)

                    Audio.tools.inv_mel_spec(mel, os.path.join(
                        synth_path,
                        "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(mel_postnet, os.path.join(
                        synth_path,
                        "step_{}_postnet_griffin_lim.wav".format(current_step)))

                    # Vocoder samples go to synth_path, which main() creates
                    # (hp.test_path is never created here).
                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(mel_torch, melgan, os.path.join(
                            synth_path, 'step_{}_{}.wav'.format(
                                current_step, hp.vocoder)))
                        utils.melgan_infer(mel_postnet_torch, melgan, os.path.join(
                            synth_path, 'step_{}_postnet_{}.wav'.format(
                                current_step, hp.vocoder)))
                        utils.melgan_infer(mel_target_torch, melgan, os.path.join(
                            synth_path, 'step_{}_ground-truth_{}.wav'.format(
                                current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(mel_torch, waveglow, os.path.join(
                            synth_path, 'step_{}_{}.wav'.format(
                                current_step, hp.vocoder)))
                        utils.waveglow_infer(mel_postnet_torch, waveglow, os.path.join(
                            synth_path, 'step_{}_postnet_{}.wav'.format(
                                current_step, hp.vocoder)))
                        utils.waveglow_infer(mel_target_torch, waveglow, os.path.join(
                            synth_path, 'step_{}_ground-truth_{}.wav'.format(
                                current_step, hp.vocoder)))

                    f0_cpu = f0[0, :length].detach().cpu().numpy()
                    energy_cpu = energy[0, :length].detach().cpu().numpy()
                    f0_output_cpu = f0_output[0, :length].detach().cpu().numpy()
                    energy_output_cpu = energy_output[
                        0, :length].detach().cpu().numpy()

                    utils.plot_data(
                        [(mel_postnet.numpy(), f0_output_cpu, energy_output_cpu),
                         (mel_target_cpu.numpy(), f0_cpu, energy_cpu)],
                        ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(
                            synth_path, 'step_{}.png'.format(current_step)))

                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        validation_scalars = (
                            ('Loss/total_loss', t_l),
                            ('Loss/mel_loss', m_l),
                            ('Loss/mel_postnet_loss', m_p_l),
                            ('Loss/duration_loss', d_l),
                            ('Loss/F0_loss', f_l),
                            ('Loss/energy_loss', e_l),
                        )
                        for tag, value in validation_scalars:
                            logger.add_scalars(
                                tag, {'validation': value}, current_step)
                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    # Collapse the timing window into its mean so the
                    # remaining-time estimate stays bounded in memory.
                    Time = np.array([np.mean(Time)])
def synthesize(model, waveglow, melgan, text, sentence, prefix=''):
    """Synthesize one input sequence and write audio and plots under ``hp.test_path``.

    With ``hp.use_spk_embed`` the input is synthesized for up to three speakers
    taken from ``inv_spk_table``. Each speaker's vocoded wav (only if a vocoder
    is given) and spectrogram plot are written to
    ``<hp.test_path>/<hp.dataset>/step <args.step>/<speaker>/``.

    Without speaker embeddings, a Griffin-Lim wav and a plot are written to
    ``hp.test_path``. When a vocoder is given, a vocoded wav is also written to
    ``<hp.test_path>/<hp.dataset>/``.

    Args:
        model: FastSpeech2 model, called as ``model(text, src_len,
            speaker_ids=...)``.
        waveglow: Optional WaveGlow vocoder (``None`` to skip).
        melgan: Optional MelGAN vocoder (``None`` to skip). When both vocoders
            are given, MelGAN's output wins.
        text: Input id tensor. It is assumed to be 2-D with a batch of one,
            since ``text.shape[1]`` is used as the single source length.
        sentence: Raw sentence used in file names, truncated to 150 chars.
        prefix: Prefix for every output file name.
    """
    sentence = sentence[:150]  # long filename will result in OS Error
    src_len = torch.from_numpy(np.array([text.shape[1]])).to(device)

    # create dir
    dataset_dir = os.path.join(hp.test_path, hp.dataset)
    os.makedirs(dataset_dir, exist_ok=True)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if hp.use_spk_embed:
            # Speaker selection is still hard-coded (TODO): three speakers
            # starting at the sixth key of inv_spk_table. A local count is used
            # instead of overwriting the global hp.batch_size.
            num_speakers = 3
            spk_ids = torch.tensor(
                list(inv_spk_table.keys())[5:5 + num_speakers]).to(
                    torch.int64).to(device)
            # Repeat by the number of ids actually found so text, lengths and
            # speaker ids always share one batch size.
            batch = len(spk_ids)
            text = text.repeat(batch, 1)
            src_len = src_len.repeat(batch)

            _, mel_postnet, _, f0_output, energy_output, _, _, mel_len = model(
                text, src_len, speaker_ids=spk_ids)

            # Frames where the length mask is True are replaced by a constant
            # -5, presumably so padding is vocoded as silence.
            mel_mask = get_mask_from_lengths(mel_len, None)
            mel_mask = mel_mask.unsqueeze(-1).expand(mel_postnet.size())
            silence = (torch.ones(mel_postnet.size()) * -5).to(device)
            mel_postnet = torch.where(~mel_mask, mel_postnet, silence)
            mel_postnet_torch = mel_postnet.transpose(1, 2).detach()

            # `wavs` stays None without a vocoder (previously a NameError).
            wavs = None
            if waveglow is not None:
                wavs = utils.waveglow_infer_batch(mel_postnet_torch, waveglow)
            if melgan is not None:
                wavs = utils.melgan_infer_batch(mel_postnet_torch, melgan)

            for i, spk_id in enumerate(spk_ids):
                spker = inv_spk_table[int(spk_id)]
                base_dir_i = os.path.join(
                    dataset_dir, "step {}".format(args.step), spker)
                os.makedirs(base_dir_i, exist_ok=True)

                if wavs is not None:
                    path_i = os.path.join(
                        base_dir_i,
                        '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
                    soundfile.write(path_i, wavs[i], hp.sampling_rate)

                mel_postnet_i = mel_postnet[i].cpu().transpose(0, 1).detach()
                f0_i = f0_output[i].detach().cpu().numpy()
                energy_i = energy_output[i].detach().cpu().numpy()
                utils.plot_data(
                    [(mel_postnet_i.numpy(), f0_i, energy_i)],
                    ['Synthesized Spectrogram'],
                    filename=os.path.join(
                        base_dir_i, '{}_{}.png'.format(prefix, sentence)))
        else:
            _, mel_postnet, _, f0_output, energy_output, _, _, _ = model(
                text, src_len, speaker_ids=None)
            mel_postnet_torch = mel_postnet.transpose(1, 2).detach()
            mel_postnet = mel_postnet[0].cpu().transpose(0, 1).detach()
            f0_output = f0_output[0].detach().cpu().numpy()
            energy_output = energy_output[0].detach().cpu().numpy()

            Audio.tools.inv_mel_spec(
                mel_postnet,
                os.path.join(hp.test_path,
                             '{}_griffin_lim_{}.wav'.format(prefix, sentence)))

            # There is no speaker in this branch; the old file name referenced
            # the undefined `spker` and raised NameError.
            wav_path = os.path.join(
                dataset_dir, '{}_{}_{}.wav'.format(prefix, hp.vocoder, sentence))
            if waveglow is not None:
                utils.waveglow_infer(mel_postnet_torch, waveglow, wav_path)
            if melgan is not None:
                utils.melgan_infer(mel_postnet_torch, melgan, wav_path)

            utils.plot_data(
                [(mel_postnet.numpy(), f0_output, energy_output)],
                ['Synthesized Spectrogram'],
                filename=os.path.join(
                    hp.test_path, '{}_{}.png'.format(prefix, sentence)))
mel_postnet_torch, melgan, os.path.join( hp.synth_path, 'step_{}_postnet_{}.wav'.format( current_step, hp.vocoder))) utils.melgan_infer( mel_target_torch, melgan, os.path.join( hp.synth_path, 'step_{}_ground-truth_{}.wav'.format( current_step, hp.vocoder))) elif hp.vocoder == 'waveglow': # utils.waveglow_infer(mel_torch, waveglow, os.path.join( # hp.synth_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder))) utils.waveglow_infer( mel_postnet_torch + hp.mel_mean, waveglow, os.path.join( hp.synth_path, 'step_{}_postnet_{}.wav'.format( current_step, hp.vocoder))) utils.waveglow_infer( mel_target_torch, waveglow, os.path.join( hp.synth_path, 'step_{}_ground-truth_{}.wav'.format( current_step, hp.vocoder))) utils.plot_data( [mel_postnet.numpy() + hp.mel_mean, mel_target.numpy()], ['Synthetized Spectrogram', 'Ground-Truth Spectrogram'], filename=os.path.join(synth_path, 'step_{}.png'.format(current_step)))
def evaluate(model, step, vocoder=None):
    """Compute averaged FastSpeech validation losses on ``val.txt``.

    Args:
        model: FastSpeech model, called as ``model(character, src_pos,
            mel_pos=..., mel_max_length=..., length_target=...)`` and returning
            ``(mel_output, mel_postnet_output, duration_predictor_output)``.
        step: Training step; only used in the printed / logged header.
        vocoder: Optional vocoder. When given and ``hp.vocoder == 'waveglow'``,
            the ground-truth and predicted mels of every validation sample are
            vocoded, and the predicted mel is saved as ``.npy`` in
            ``hp.eval_path``.

    Returns:
        Tuple ``(total_loss, duration_loss, mel_loss, mel_postnet_loss)``
        averaged over all validation sub-batches.
    """
    torch.manual_seed(0)

    # Get dataset
    print("Load data to buffer")
    buffer = get_data_to_buffer('val.txt')
    dataset = BufferDataset(buffer)

    # Get validation loader
    validating_loader = DataLoader(dataset,
                                   batch_size=hp.batch_expand_size * hp.batch_size,
                                   shuffle=True,
                                   collate_fn=collate_fn_tensor,
                                   drop_last=False,
                                   num_workers=0)
    # len() gives the number of batches directly. The former extra pass over
    # the data doubled the loading cost and consumed shuffle RNG state.
    print(len(validating_loader))

    # Get Loss
    fastspeech_loss = DNNLoss().to(device)

    if vocoder is not None:
        # Output directory for vocoded audio and saved mels.
        os.makedirs(hp.eval_path, exist_ok=True)

    t_l = []
    d_l = []
    mel_l = []
    mel_p_l = []
    with torch.no_grad():
        for batchs in validating_loader:
            # real batch start here
            for db in batchs:
                # Get Data
                id_ = db["name"]
                mel_len = torch.from_numpy(db["mel_len"]).long().to(device)
                character = db["text"].long().to(device)
                mel_target = db["mel_target"].float().to(device)
                duration = db["duration"].int().to(device)
                mel_pos = db["mel_pos"].long().to(device)
                src_pos = db["src_pos"].long().to(device)
                max_mel_len = db["mel_max_len"]

                # Forward
                mel_output, mel_postnet_output, duration_predictor_output = model(
                    character,
                    src_pos,
                    mel_pos=mel_pos,
                    mel_max_length=max_mel_len,
                    length_target=duration)

                # Cal Loss
                mel_loss, mel_postnet_loss, duration_loss = fastspeech_loss(
                    mel_output, mel_postnet_output, duration_predictor_output,
                    mel_target, duration)
                total_loss = mel_loss + mel_postnet_loss + duration_loss

                t_l.append(total_loss.item())
                d_l.append(duration_loss.item())
                mel_l.append(mel_loss.item())
                mel_p_l.append(mel_postnet_loss.item())

                if vocoder is not None:
                    # Run vocoding only when the vocoder is defined
                    for k in range(len(mel_target)):
                        basename = id_[k]
                        gt_length = mel_len[k]
                        mel_target_torch = mel_target[
                            k:k + 1, :gt_length].transpose(1, 2).detach()
                        # Predictions are not trimmed: the output length is
                        # unknown here.
                        mel_postnet_torch = mel_postnet_output[
                            k:k + 1, :].transpose(1, 2).detach()
                        mel_postnet = mel_postnet_output[k, :].cpu().transpose(
                            0, 1).detach()

                        if hp.vocoder == 'waveglow':
                            utils.waveglow_infer(
                                mel_target_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'ground-truth_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                            utils.waveglow_infer(
                                mel_postnet_torch, vocoder,
                                os.path.join(
                                    hp.eval_path,
                                    'eval_{}_{}.wav'.format(
                                        basename, hp.vocoder)))
                        np.save(
                            os.path.join(hp.eval_path,
                                         'eval_{}_mel.npy'.format(basename)),
                            mel_postnet.numpy())

    t_l = sum(t_l) / len(t_l)
    d_l = sum(d_l) / len(d_l)
    mel_l = sum(mel_l) / len(mel_l)
    mel_p_l = sum(mel_p_l) / len(mel_p_l)

    str1 = "FastSpeech2 Step {},".format(step)
    str2 = "Total Loss {},".format(t_l)
    str3 = "Duration Loss: {}".format(d_l)
    str4 = "Mel Loss: {}".format(mel_l)
    str5 = "Mel Postnet Loss: {}".format(mel_p_l)

    print("\n" + str1)
    print(str2)
    print(str3)
    print(str4)
    print(str5)

    with open(os.path.join(hp.log_path, "eval.txt"), "a") as f_log:
        for line in (str1, str2, str3, str4, str5):
            f_log.write(line + "\n")
        f_log.write("\n")

    return t_l, d_l, mel_l, mel_p_l
def main(args): # Get device device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Get dataset dataset = Dataset("test.txt") loader = DataLoader(dataset, batch_size=hp.batch_size**2, shuffle=False, collate_fn=dataset.collate_fn, drop_last=False, num_workers=0) model = get_FastSpeech2(args.step, full_path=args.model_fs).to(device) # Load vocoder if hp.vocoder == 'melgan': melgan = utils.get_melgan(full_path=args.model_melgan) elif hp.vocoder == 'waveglow': waveglow = utils.get_waveglow() # Init logger log_path = hp.log_path if not os.path.exists(log_path): os.makedirs(log_path) os.makedirs(os.path.join(log_path, 'test')) test_logger = SummaryWriter(os.path.join(log_path, 'test')) # Init synthesis directory test_path = hp.test_path if not os.path.exists(test_path): os.makedirs(test_path) current_step = args.step findex = open(os.path.join(test_path, "index.tsv"), "w") # Testing print("Generate test audio") prefix = "" for i, batchs in enumerate(loader): for j, data_of_batch in enumerate(batchs): print("Start batch", j) fids = data_of_batch["id"] # Get Data text = torch.from_numpy(data_of_batch["text"]).long().to(device) mel_target = torch.from_numpy( data_of_batch["mel_target"]).float().to(device) D = torch.from_numpy(data_of_batch["D"]).long().to(device) log_D = torch.from_numpy(data_of_batch["log_D"]).float().to(device) f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device) energy = torch.from_numpy( data_of_batch["energy"]).float().to(device) src_len = torch.from_numpy( data_of_batch["src_len"]).long().to(device) mel_len = torch.from_numpy( data_of_batch["mel_len"]).long().to(device) max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32) max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32) mel_output, mel_postnet_output, log_duration_output, f0_output, energy_output, src_mask, mel_mask, _ = model( text, src_len) for i in range(len(mel_len)): fid = fids[i] print("Generate audio for:", fid) length = mel_len[i].item() 
_mel_target_torch = mel_target[i, :length].detach().unsqueeze( 0).transpose(1, 2) _mel_target = mel_target[i, :length].detach().cpu().transpose( 0, 1) _mel_torch = mel_output[i, :length].detach().unsqueeze( 0).transpose(1, 2) _mel = mel_output[i, :length].detach().cpu().transpose(0, 1) _mel_postnet_torch = mel_postnet_output[ i, :length].detach().unsqueeze(0).transpose(1, 2) _mel_postnet = mel_postnet_output[ i, :length].detach().cpu().transpose(0, 1) fname = "{}{}_step_{}_gt_griffin_lim.wav".format( prefix, fid, current_step) Audio.tools.inv_mel_spec(_mel_target, os.path.join(hp.test_path, fname)) _write_index_line(findex, "Griffin Lim", "vocoder", fname, "", fid) fname = "{}{}_step_{}_griffin_lim.wav".format( prefix, fid, current_step) Audio.tools.inv_mel_spec(_mel, os.path.join(hp.test_path, fname)) _write_index_line(findex, "FastSpeech2 + GL", "tts", fname, "", fid) fname = "{}{}_step_{}_postnet_griffin_lim.wav".format( prefix, fid, current_step) Audio.tools.inv_mel_spec(_mel_postnet, os.path.join(hp.test_path, fname)) _write_index_line(findex, "FastSpeech2 + PN + GL", "tts", fname, "", fid) if hp.vocoder == 'melgan': fname = '{}{}_step_{}_ground-truth_{}.wav'.format( prefix, fid, current_step, hp.vocoder) utils.melgan_infer(_mel_target_torch, melgan, os.path.join(hp.test_path, fname)) _write_index_line(findex, "Melgan", "vocoder", fname, "", fid) fname = '{}{}_step_{}_{}.wav'.format( prefix, fid, current_step, hp.vocoder) utils.melgan_infer(_mel_torch, melgan, os.path.join(hp.test_path, fname)) _write_index_line(findex, "FastSpeech2 + Melgan", "tts", fname, "", fid) fname = '{}{}_step_{}_postnet_{}.wav'.format( prefix, fid, current_step, hp.vocoder) utils.melgan_infer(_mel_postnet_torch, melgan, os.path.join(hp.test_path, fname)) _write_index_line(findex, "FastSpeech2 + PN + Melgan", "tts", fname, "", fid) elif hp.vocoder == 'waveglow': utils.waveglow_infer( _mel_torch, waveglow, os.path.join( hp.test_path, 'step_{}_{}.wav'.format(current_step, hp.vocoder))) 
utils.waveglow_infer( _mel_postnet_torch, waveglow, os.path.join( hp.test_path, 'step_{}_postnet_{}.wav'.format( current_step, hp.vocoder))) utils.waveglow_infer( _mel_target_torch, waveglow, os.path.join( hp.test_path, 'step_{}_ground-truth_{}.wav'.format( current_step, hp.vocoder)))