def get_FastSpeech2(num):
    checkpoint_path = os.path.join(hp.checkpoint_path,
                                   "checkpoint_{}.pth.tar".format(num))
    checkpoint = torch.load(checkpoint_path)
    # Infer the speaker count from the stored speaker-embedding table.
    n_spkers = checkpoint['model']['module.embed_speakers.weight'].shape[0]
    if hp.use_spk_embed:
        model = nn.DataParallel(FastSpeech2(True, n_spkers))
    else:
        model = nn.DataParallel(FastSpeech2())
    model.load_state_dict(checkpoint['model'])
    model.requires_grad_(False)  # freeze all parameters for inference
    model.eval()
    return model
def get_FastSpeech2(num):
    checkpoint_path = os.path.join(hp.checkpoint_path,
                                   "checkpoint_{}.pth.tar".format(num))
    model = nn.DataParallel(FastSpeech2())
    model.load_state_dict(torch.load(checkpoint_path)['model'])
    model.requires_grad_(False)  # freeze all parameters for inference
    model.eval()
    return model
def get_FastSpeech2(num, full_path=None):
    if full_path:
        checkpoint_path = full_path
    else:
        checkpoint_path = os.path.join(hp.checkpoint_path,
                                       "checkpoint_{}.pth.tar".format(num))
    model = nn.DataParallel(FastSpeech2())
    # map_location lets a GPU-trained checkpoint load on the current device.
    model.load_state_dict(
        torch.load(checkpoint_path, map_location=device)['model'])
    model.requires_grad_(False)
    model.eval()
    return model
def get_FastSpeech2(num):
    checkpoint_path = os.path.join(hp.checkpoint_path,
                                   "checkpoint_{}.pth.tar".format(num))
    model = nn.DataParallel(FastSpeech2())
    model_data = torch.load(checkpoint_path)['model']
    # To load this DataParallel checkpoint into a bare FastSpeech2() instead,
    # strip the 'module.' prefix from every key first:
    # model_data = {key[7:]: value for key, value in model_data.items()}
    model.load_state_dict(model_data)
    model.requires_grad_(False)
    model.eval()
    return model
def get_FastSpeech2(model_path, with_hanzi=True):
    print('loading model from', model_path)
    model = FastSpeech2(py_vocab_size, hz_vocab_size)
    sd = torch.load(model_path, map_location='cpu')
    if 'model' in sd.keys():
        # A full checkpoint file: keep only the model weights,
        # not the optimizer state.
        sd = sd['model']
    model.load_state_dict(sd)
    model.requires_grad_(False)
    model.eval()
    return model
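# A recurring gotcha in the loaders above: nn.DataParallel prefixes every
# parameter name with 'module.', so a checkpoint saved from a wrapped model
# only loads directly into another wrapped model. A minimal sketch of both
# directions, assuming a hypothetical checkpoint file (not from this repo):

import torch
import torch.nn as nn

ckpt = torch.load("checkpoint_900000.pth.tar", map_location="cpu")  # hypothetical path
state = ckpt["model"]

# Wrapped -> wrapped: keys already carry the 'module.' prefix.
wrapped = nn.DataParallel(FastSpeech2())
wrapped.load_state_dict(state)

# Wrapped -> bare: strip the 'module.' prefix first (what the commented-out
# code in the variant above was doing).
bare = FastSpeech2()
bare.load_state_dict({
    k[len("module."):] if k.startswith("module.") else k: v
    for k, v in state.items()
})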
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    if hp.use_spk_embed:
        n_spkers = len(dataset.spk_table.keys())
        model = nn.DataParallel(FastSpeech2(n_spkers=n_spkers)).to(device)
    else:
        model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        # melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Init evaluation directory
    eval_path = hp.eval_path
    if not os.path.exists(eval_path):
        os.makedirs(eval_path)

    # Define some information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get training loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(loader) * hp.batch_size + 1
                print("step : {}".format(current_step), end='\r', flush=True)

                ### Get Data ###
                if hp.use_spk_embed:
                    spk_ids = torch.tensor(data_of_batch["spk_ids"]).to(
                        torch.int64).to(device)
                else:
                    spk_ids = None
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                ### Forward ###
                mel_output, mel_postnet_output, log_duration_output, f0_output, \
                    energy_output, src_mask, mel_mask, _ = model(
                        text, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len, spk_ids)

                ### Cal Loss ###
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0,
                    energy_output, energy, mel_output, mel_postnet_output,
                    mel_target, ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss \
                    + 0.01 * f_loss + 0.1 * e_loss

                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()

                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                ### Backward (with gradient accumulation) ###
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                ### Update weights ###
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                ### Print ###
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = ("Total Loss: {:.4f}, Mel Loss: {:.4f}, "
                            "Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, "
                            "F0 Loss: {:.4f}, Energy Loss: {:.4f};").format(
                                t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = ("Time Used: {:.3f}s, "
                            "Estimated Time Remaining: {:.3f}s.").format(
                                Now - Start,
                                (total_step - current_step) * np.mean(Time))
                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                ### Save model ###
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                ### Synth ###
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    print("step: {} , length {}, {}".format(
                        current_step, length, mel_len))
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)
                    spk_id = dataset.inv_spk_table[int(spk_ids[0])]
                    if hp.vocoder == 'melgan':
                        vocoder.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        vocoder.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_spk_{}_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_postnet_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))
                        vocoder.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_spk_{}_ground-truth_{}.wav'.format(
                                    current_step, spk_id, hp.vocoder)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(
                            synth_path, 'step_{}.png'.format(current_step)))

                ### Evaluation ###
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                    model.train()

                ### Time ###
                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time,
                                     [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
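# The hp.acc_steps logic above is gradient accumulation: the loss is scaled
# down, gradients are summed over several micro-batches, and the optimizer
# only steps (and gradients are only cleared) every acc_steps iterations.
# A stripped-down sketch of the same pattern; model, loader, optimizer, and
# criterion are illustrative placeholders, not names from this repo:

import torch
import torch.nn as nn

ACC_STEPS = 4  # assumed value for illustration
optimizer.zero_grad()
for step, (x, y) in enumerate(loader, start=1):
    loss = criterion(model(x), y) / ACC_STEPS  # scale so summed grads match one big batch
    loss.backward()                            # grads accumulate across micro-batches
    if step % ACC_STEPS != 0:
        continue
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()
    optimizer.zero_grad()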
def main(args):
    torch.manual_seed(0)

    # Get device
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = 'cuda'

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoaderX(dataset,
                         batch_size=hp.batch_size * 4,
                         shuffle=True,
                         collate_fn=dataset.collate_fn,
                         drop_last=True,
                         num_workers=8)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    # model = FastSpeech2().to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    if hp.vocoder == 'melgan':
        melgan = utils.get_melgan()
        melgan.to(device)
    elif hp.vocoder == 'waveglow':
        waveglow = utils.get_waveglow()
        waveglow.to(device)

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    current_time = time.strftime("%Y-%m-%dT%H:%M", time.localtime())
    train_logger = SummaryWriter(log_dir='log/train/' + current_time)
    val_logger = SummaryWriter(log_dir='log/validation/' + current_time)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define some information
    Time = np.array([])
    Start = time.perf_counter()
    current_step0 = 0

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get training loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * len(batchs) + j + args.restore_step + \
                    epoch * len(loader) * len(batchs) + 1

                # Get data
                condition = torch.from_numpy(
                    data_of_batch["condition"]).long().to(device)
                mel_refer = torch.from_numpy(
                    data_of_batch["mel_refer"]).float().to(device)
                if hp.vocoder == 'WORLD':
                    ap_target = torch.from_numpy(
                        data_of_batch["ap_target"]).float().to(device)
                    sp_target = torch.from_numpy(
                        data_of_batch["sp_target"]).float().to(device)
                else:
                    mel_target = torch.from_numpy(
                        data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward and loss: the WORLD variant predicts aperiodicity
                # (ap) and spectral envelope (sp) instead of a mel spectrogram.
                if hp.vocoder == 'WORLD':
                    ap_output, sp_output, sp_postnet_output, \
                        log_duration_output, f0_output, energy_output, \
                        src_mask, ap_mask, sp_mask, \
                        variance_adaptor_output, decoder_output = model(
                            condition, src_len, mel_len, D, f0, energy,
                            max_src_len, max_mel_len)
                    ap_loss, sp_loss, sp_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, D, f0_output, f0,
                        energy_output, energy,
                        ap_output=ap_output,
                        sp_output=sp_output,
                        sp_postnet_output=sp_postnet_output,
                        ap_target=ap_target,
                        sp_target=sp_target,
                        src_mask=src_mask,
                        ap_mask=ap_mask,
                        sp_mask=sp_mask)
                    total_loss = ap_loss + sp_loss + sp_postnet_loss \
                        + d_loss + f_loss + e_loss
                else:
                    mel_output, mel_postnet_output, log_duration_output, \
                        f0_output, energy_output, src_mask, mel_mask, _ = model(
                            condition, mel_refer, src_len, mel_len, D, f0,
                            energy, max_src_len, max_mel_len)
                    mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                        log_duration_output, log_D, f0_output, f0,
                        energy_output, energy,
                        mel_output=mel_output,
                        mel_postnet_output=mel_postnet_output,
                        mel_target=mel_target,
                        src_mask=~src_mask,
                        mel_mask=~mel_mask)
                    total_loss = mel_loss + mel_postnet_loss + d_loss \
                        + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                if hp.vocoder == 'WORLD':
                    ap_l = ap_loss.item()
                    sp_l = sp_loss.item()
                    sp_p_l = sp_postnet_loss.item()
                else:
                    m_l = mel_loss.item()
                    m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                # (Per-loss text-file logging from the other variants is
                # disabled here; the TensorBoard scalars below cover the same
                # values.)

                # Backward (with gradient accumulation)
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clip gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch[{}/{}],Step[{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    if hp.vocoder == 'WORLD':
                        str2 = ("Loss:{:.4f},ap:{:.4f},sp:{:.4f},spPN:{:.4f},"
                                "Dur:{:.4f},F0:{:.4f},Energy:{:.4f};").format(
                                    t_l, ap_l, sp_l, sp_p_l, d_l, f_l, e_l)
                    else:
                        str2 = ("Loss:{:.4f},Mel:{:.4f},MelPN:{:.4f},"
                                "Dur:{:.4f},F0:{:.4f},Energy:{:.4f};").format(
                                    t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = "T:{:.1f}s,ETA:{:.1f}s.".format(
                        (Now - Start) / (current_step - current_step0),
                        (total_step - current_step) * np.mean(Time))
                    print(str1 + str2 + str3)

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    if hp.vocoder == 'WORLD':
                        train_logger.add_scalar('Loss/ap_loss', ap_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_loss', sp_l,
                                                current_step)
                        train_logger.add_scalar('Loss/sp_postnet_loss', sp_p_l,
                                                current_step)
                    else:
                        train_logger.add_scalar('Loss/mel_loss', m_l,
                                                current_step)
                        train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                                current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                # Save model
                if current_step % hp.save_step == 0 or current_step == 20:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                # Synth
                if current_step % hp.synth_step == 0 or current_step == 5:
                    length = mel_len[0].item()
                    if hp.vocoder == 'WORLD':
                        ap_target_torch = ap_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        ap_torch = ap_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_target_torch = sp_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_torch = sp_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        sp_postnet_torch = sp_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                    else:
                        mel_target_torch = mel_target[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_target = mel_target[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_torch = mel_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel = mel_output[
                            0, :length].detach().cpu().transpose(0, 1)
                        mel_postnet_torch = mel_postnet_output[
                            0, :length].detach().unsqueeze(0).transpose(1, 2)
                        mel_postnet = mel_postnet_output[
                            0, :length].detach().cpu().transpose(0, 1)
                        # Griffin-Lim inversion (Audio.tools.inv_mel_spec) is
                        # disabled in this variant.
                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()

                    if hp.vocoder == 'melgan':
                        utils.melgan_infer(
                            mel_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_postnet_torch, melgan,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.melgan_infer(
                            mel_target_torch, melgan,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'waveglow':
                        utils.waveglow_infer(
                            mel_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_postnet_torch, waveglow,
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)))
                        utils.waveglow_infer(
                            mel_target_torch, waveglow,
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)))
                    elif hp.vocoder == 'WORLD':
                        # WORLD resynthesis from (ap, sp, f0); axes are
                        # swapped back to (frames, bins) before synthesis.
                        wav = utils.world_infer(
                            np.swapaxes(ap_torch[0].cpu().numpy(), 0, 1),
                            np.swapaxes(sp_postnet_torch[0].cpu().numpy(),
                                        0, 1),
                            f0_output)
                        sf.write(
                            os.path.join(
                                hp.synth_path, 'step_{}_postnet_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)
                        wav = utils.world_infer(
                            np.swapaxes(ap_target_torch[0].cpu().numpy(),
                                        0, 1),
                            np.swapaxes(sp_target_torch[0].cpu().numpy(),
                                        0, 1),
                            f0)
                        sf.write(
                            os.path.join(
                                hp.synth_path,
                                'step_{}_ground-truth_{}.wav'.format(
                                    current_step, hp.vocoder)), wav, 32000)

                        utils.plot_data(
                            [(sp_postnet_torch[0].cpu().numpy(), f0_output,
                              energy_output),
                             (sp_target_torch[0].cpu().numpy(), f0, energy)],
                            ['Synthesized Spectrogram',
                             'Ground-Truth Spectrogram'],
                            filename=os.path.join(
                                synth_path,
                                'step_{}.png'.format(current_step)))
                        plt.matshow(sp_postnet_torch[0].cpu().numpy())
                        plt.savefig(
                            os.path.join(
                                synth_path,
                                'sp_postnet_{}.png'.format(current_step)))
                        plt.matshow(ap_torch[0].cpu().numpy())
                        plt.savefig(
                            os.path.join(synth_path,
                                         'ap_{}.png'.format(current_step)))
                        plt.matshow(
                            variance_adaptor_output[0].detach().cpu().numpy())
                        # plt.savefig(os.path.join(synth_path, 'va_{}.png'.format(current_step)))
                        # plt.matshow(decoder_output[0].detach().cpu().numpy())
                        # plt.savefig(os.path.join(synth_path, 'encoder_{}.png'.format(current_step)))
                        plt.cla()

                        with open(
                                os.path.join(
                                    synth_path,
                                    'D_{}.txt'.format(current_step)),
                                'w') as fout:
                            fout.write(
                                str(log_duration_output[0].detach().cpu()
                                    .numpy()) + '\n')
                            fout.write(
                                str(D[0].detach().cpu().numpy()) + '\n')
                            fout.write(
                                str(condition[0, :, 2].detach().cpu().numpy())
                                + '\n')

                # (The periodic evaluation block is disabled in this variant.)

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time,
                                     [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
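# utils.world_infer above turns WORLD features (aperiodicity ap, spectral
# envelope sp, fundamental frequency f0) back into a waveform. A plausible
# sketch with the pyworld package, assuming uncoded features, the 32 kHz
# rate used in the sf.write calls above, and a 5 ms frame period (all three
# are assumptions; the repo's actual helper may differ):

import numpy as np
import pyworld as pw

def world_infer_sketch(ap, sp, f0, fs=32000, frame_period=5.0):
    # pyworld expects contiguous float64 arrays with frames on axis 0.
    f0 = np.ascontiguousarray(f0.astype(np.float64))
    sp = np.ascontiguousarray(sp.astype(np.float64))
    ap = np.ascontiguousarray(ap.astype(np.float64))
    return pw.synthesize(f0, sp, ap, fs, frame_period)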
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Read normalization statistics (each .npy holds a [mean, std] pair)
    mean_mel, std_mel = torch.tensor(
        np.load(os.path.join(hp.preprocessed_path, "mel_stat.npy")),
        dtype=torch.float).to(device)
    mean_f0, std_f0 = torch.tensor(
        np.load(os.path.join(hp.preprocessed_path, "f0_stat.npy")),
        dtype=torch.float).to(device)
    mean_energy, std_energy = torch.tensor(
        np.load(os.path.join(hp.preprocessed_path, "energy_stat.npy")),
        dtype=torch.float).to(device)
    mean_mel, std_mel = mean_mel.reshape(1, -1), std_mel.reshape(1, -1)
    mean_f0, std_f0 = mean_f0.reshape(1, -1), std_f0.reshape(1, -1)
    mean_energy, std_energy = mean_energy.reshape(1, -1), \
        std_energy.reshape(1, -1)

    # Load vocoder
    if hp.vocoder == 'vocgan':
        vocoder = utils.get_vocgan(ckpt_path=hp.vocoder_pretrained_model_path)
        vocoder.to(device)
    else:
        vocoder = None

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
        os.makedirs(os.path.join(log_path, 'train'))
        os.makedirs(os.path.join(log_path, 'validation'))
    train_logger = SummaryWriter(os.path.join(log_path, 'train'))
    val_logger = SummaryWriter(os.path.join(log_path, 'validation'))

    # Define some information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get training loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(loader) * hp.batch_size + 1

                # Get data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).long().to(device)
                log_D = torch.from_numpy(
                    data_of_batch["log_D"]).float().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_src_len = np.max(data_of_batch["src_len"]).astype(np.int32)
                max_mel_len = np.max(data_of_batch["mel_len"]).astype(np.int32)

                # Forward
                mel_output, mel_postnet_output, log_duration_output, \
                    f0_output, energy_output, src_mask, mel_mask, _ = model(
                        text, src_len, mel_len, D, f0, energy,
                        max_src_len, max_mel_len)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    log_duration_output, log_D, f0_output, f0,
                    energy_output, energy, mel_output, mel_postnet_output,
                    mel_target, ~src_mask, ~mel_mask)
                total_loss = mel_loss + mel_postnet_loss + d_loss \
                    + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward (with gradient accumulation)
                total_loss = total_loss / hp.acc_steps
                total_loss.backward()
                if current_step % hp.acc_steps != 0:
                    continue

                # Clip gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()
                scheduled_optim.zero_grad()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = ("Total Loss: {:.4f}, Mel Loss: {:.4f}, "
                            "Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, "
                            "F0 Loss: {:.4f}, Energy Loss: {:.4f};").format(
                                t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = ("Time Used: {:.3f}s, "
                            "Estimated Time Remaining: {:.3f}s.").format(
                                Now - Start,
                                (total_step - current_step) * np.mean(Time))
                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    train_logger.add_scalar('Loss/total_loss', t_l,
                                            current_step)
                    train_logger.add_scalar('Loss/mel_loss', m_l, current_step)
                    train_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                            current_step)
                    train_logger.add_scalar('Loss/duration_loss', d_l,
                                            current_step)
                    train_logger.add_scalar('Loss/F0_loss', f_l, current_step)
                    train_logger.add_scalar('Loss/energy_loss', e_l,
                                            current_step)

                # Save model
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                # Evaluation
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step, vocoder)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        val_logger.add_scalar('Loss/total_loss', t_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_loss', m_l,
                                              current_step)
                        val_logger.add_scalar('Loss/mel_postnet_loss', m_p_l,
                                              current_step)
                        val_logger.add_scalar('Loss/duration_loss', d_l,
                                              current_step)
                        val_logger.add_scalar('Loss/F0_loss', f_l,
                                              current_step)
                        val_logger.add_scalar('Loss/energy_loss', e_l,
                                              current_step)
                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time,
                                     [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
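# Every variant drives Adam through ScheduledOptim(optimizer,
# hp.decoder_hidden, hp.n_warm_up_step, restore_step). FastSpeech2-style
# repos usually implement the Transformer "Noam" schedule here,
# lr = d_model^-0.5 * min(step^-0.5, step * warmup^-1.5). A minimal sketch
# under that assumption; it mirrors the step_and_update_lr / zero_grad calls
# used above but is not this repo's actual class:

import numpy as np

class ScheduledOptimSketch:
    def __init__(self, optimizer, d_model, n_warmup_steps, current_step):
        self._optimizer = optimizer
        self.init_lr = np.power(d_model, -0.5)
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = current_step

    def _lr_scale(self):
        # Linear warmup up to n_warmup_steps, then inverse-sqrt decay.
        return min(np.power(self.n_current_steps, -0.5),
                   self.n_current_steps * np.power(self.n_warmup_steps, -1.5))

    def step_and_update_lr(self):
        # Increment first so the scale is defined from step 1 onward.
        self.n_current_steps += 1
        lr = self.init_lr * self._lr_scale()
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()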
                      vocab_file_hz=None)
    py_vocab_size = len(dataset.py_vocab)
    hz_vocab_size = None

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=False,
                        num_workers=8)

    # Define model
    model = FastSpeech2(py_vocab_size, hz_vocab_size).to(device)
    num_param = utils.get_param_num(model)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=hp.start_lr,
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=0)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
def main(args):
    torch.manual_seed(0)

    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Get dataset
    dataset = Dataset("train.txt")
    loader = DataLoader(dataset,
                        batch_size=hp.batch_size**2,
                        shuffle=True,
                        collate_fn=dataset.collate_fn,
                        drop_last=True,
                        num_workers=0)

    # Define model
    model = nn.DataParallel(FastSpeech2()).to(device)
    print("Model Has Been Defined")
    num_param = utils.get_param_num(model)
    print('Number of FastSpeech2 Parameters:', num_param)

    # Optimizer and loss
    optimizer = torch.optim.Adam(model.parameters(),
                                 betas=hp.betas,
                                 eps=hp.eps,
                                 weight_decay=hp.weight_decay)
    scheduled_optim = ScheduledOptim(optimizer, hp.decoder_hidden,
                                     hp.n_warm_up_step, args.restore_step)
    Loss = FastSpeech2Loss().to(device)
    print("Optimizer and Loss Function Defined.")

    # Load checkpoint if exists
    checkpoint_path = os.path.join(hp.checkpoint_path)
    try:
        checkpoint = torch.load(
            os.path.join(checkpoint_path,
                         'checkpoint_{}.pth.tar'.format(args.restore_step)))
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))
    except:
        print("\n---Start New Training---\n")
        if not os.path.exists(checkpoint_path):
            os.makedirs(checkpoint_path)

    # Load vocoder
    wave_glow = utils.get_WaveGlow()

    # Init logger
    log_path = hp.log_path
    if not os.path.exists(log_path):
        os.makedirs(log_path)
    logger = SummaryWriter(log_path)

    # Init synthesis directory
    synth_path = hp.synth_path
    if not os.path.exists(synth_path):
        os.makedirs(synth_path)

    # Define some information
    Time = np.array([])
    Start = time.perf_counter()

    # Training
    model = model.train()
    for epoch in range(hp.epochs):
        # Get training loader
        total_step = hp.epochs * len(loader) * hp.batch_size
        for i, batchs in enumerate(loader):
            for j, data_of_batch in enumerate(batchs):
                start_time = time.perf_counter()

                current_step = i * hp.batch_size + j + args.restore_step + \
                    epoch * len(loader) * hp.batch_size + 1

                # Init
                scheduled_optim.zero_grad()

                # Get data
                text = torch.from_numpy(
                    data_of_batch["text"]).long().to(device)
                mel_target = torch.from_numpy(
                    data_of_batch["mel_target"]).float().to(device)
                D = torch.from_numpy(data_of_batch["D"]).int().to(device)
                f0 = torch.from_numpy(data_of_batch["f0"]).float().to(device)
                energy = torch.from_numpy(
                    data_of_batch["energy"]).float().to(device)
                mel_pos = torch.from_numpy(
                    data_of_batch["mel_pos"]).long().to(device)
                src_pos = torch.from_numpy(
                    data_of_batch["src_pos"]).long().to(device)
                src_len = torch.from_numpy(
                    data_of_batch["src_len"]).long().to(device)
                mel_len = torch.from_numpy(
                    data_of_batch["mel_len"]).long().to(device)
                max_len = max(data_of_batch["mel_len"]).astype(np.int16)

                # Forward
                mel_output, mel_postnet_output, duration_output, f0_output, \
                    energy_output = model(text, src_pos, mel_pos, max_len,
                                          D, f0, energy)

                # Cal Loss
                mel_loss, mel_postnet_loss, d_loss, f_loss, e_loss = Loss(
                    duration_output, D, f0_output, f0, energy_output, energy,
                    mel_output, mel_postnet_output, mel_target,
                    src_len, mel_len)
                total_loss = mel_loss + mel_postnet_loss + d_loss \
                    + f_loss + e_loss

                # Logger
                t_l = total_loss.item()
                m_l = mel_loss.item()
                m_p_l = mel_postnet_loss.item()
                d_l = d_loss.item()
                f_l = f_loss.item()
                e_l = e_loss.item()
                with open(os.path.join(log_path, "total_loss.txt"),
                          "a") as f_total_loss:
                    f_total_loss.write(str(t_l) + "\n")
                with open(os.path.join(log_path, "mel_loss.txt"),
                          "a") as f_mel_loss:
                    f_mel_loss.write(str(m_l) + "\n")
                with open(os.path.join(log_path, "mel_postnet_loss.txt"),
                          "a") as f_mel_postnet_loss:
                    f_mel_postnet_loss.write(str(m_p_l) + "\n")
                with open(os.path.join(log_path, "duration_loss.txt"),
                          "a") as f_d_loss:
                    f_d_loss.write(str(d_l) + "\n")
                with open(os.path.join(log_path, "f0_loss.txt"),
                          "a") as f_f_loss:
                    f_f_loss.write(str(f_l) + "\n")
                with open(os.path.join(log_path, "energy_loss.txt"),
                          "a") as f_e_loss:
                    f_e_loss.write(str(e_l) + "\n")

                # Backward
                total_loss.backward()

                # Clip gradients to avoid gradient explosion
                nn.utils.clip_grad_norm_(model.parameters(),
                                         hp.grad_clip_thresh)

                # Update weights
                scheduled_optim.step_and_update_lr()

                # Print
                if current_step % hp.log_step == 0:
                    Now = time.perf_counter()
                    str1 = "Epoch [{}/{}], Step [{}/{}]:".format(
                        epoch + 1, hp.epochs, current_step, total_step)
                    str2 = ("Total Loss: {:.4f}, Mel Loss: {:.4f}, "
                            "Mel PostNet Loss: {:.4f}, Duration Loss: {:.4f}, "
                            "F0 Loss: {:.4f}, Energy Loss: {:.4f};").format(
                                t_l, m_l, m_p_l, d_l, f_l, e_l)
                    str3 = ("Time Used: {:.3f}s, "
                            "Estimated Time Remaining: {:.3f}s.").format(
                                Now - Start,
                                (total_step - current_step) * np.mean(Time))
                    print("\n" + str1)
                    print(str2)
                    print(str3)

                    with open(os.path.join(log_path, "log.txt"), "a") as f_log:
                        f_log.write(str1 + "\n")
                        f_log.write(str2 + "\n")
                        f_log.write(str3 + "\n")
                        f_log.write("\n")

                    logger.add_scalars('Loss/total_loss', {'training': t_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_loss', {'training': m_l},
                                       current_step)
                    logger.add_scalars('Loss/mel_postnet_loss',
                                       {'training': m_p_l}, current_step)
                    logger.add_scalars('Loss/duration_loss', {'training': d_l},
                                       current_step)
                    logger.add_scalars('Loss/F0_loss', {'training': f_l},
                                       current_step)
                    logger.add_scalars('Loss/energy_loss', {'training': e_l},
                                       current_step)

                # Save model
                if current_step % hp.save_step == 0:
                    torch.save(
                        {
                            'model': model.state_dict(),
                            'optimizer': optimizer.state_dict()
                        },
                        os.path.join(
                            checkpoint_path,
                            'checkpoint_{}.pth.tar'.format(current_step)))
                    print("save model at step {} ...".format(current_step))

                # Synth
                if current_step % hp.synth_step == 0:
                    length = mel_len[0].item()
                    mel_target_torch = mel_target[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_target = mel_target[
                        0, :length].detach().cpu().transpose(0, 1)
                    mel_torch = mel_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel = mel_output[0, :length].detach().cpu().transpose(0, 1)
                    mel_postnet_torch = mel_postnet_output[
                        0, :length].detach().unsqueeze(0).transpose(1, 2)
                    mel_postnet = mel_postnet_output[
                        0, :length].detach().cpu().transpose(0, 1)

                    # Griffin-Lim inversion of the raw and post-net mels
                    Audio.tools.inv_mel_spec(
                        mel,
                        os.path.join(
                            synth_path,
                            "step_{}_griffin_lim.wav".format(current_step)))
                    Audio.tools.inv_mel_spec(
                        mel_postnet,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_griffin_lim.wav".format(
                                current_step)))

                    # WaveGlow inversion of the same mels plus the target
                    waveglow.inference.inference(
                        mel_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_waveglow.wav".format(current_step)))
                    waveglow.inference.inference(
                        mel_postnet_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_postnet_waveglow.wav".format(
                                current_step)))
                    waveglow.inference.inference(
                        mel_target_torch, wave_glow,
                        os.path.join(
                            synth_path,
                            "step_{}_ground-truth_waveglow.wav".format(
                                current_step)))

                    f0 = f0[0, :length].detach().cpu().numpy()
                    energy = energy[0, :length].detach().cpu().numpy()
                    f0_output = f0_output[0, :length].detach().cpu().numpy()
                    energy_output = energy_output[
                        0, :length].detach().cpu().numpy()
                    utils.plot_data(
                        [(mel_postnet.numpy(), f0_output, energy_output),
                         (mel_target.numpy(), f0, energy)],
                        ['Synthesized Spectrogram', 'Ground-Truth Spectrogram'],
                        filename=os.path.join(
                            synth_path, 'step_{}.png'.format(current_step)))

                # Evaluation
                if current_step % hp.eval_step == 0:
                    model.eval()
                    with torch.no_grad():
                        d_l, f_l, e_l, m_l, m_p_l = evaluate(
                            model, current_step)
                        t_l = d_l + f_l + e_l + m_l + m_p_l
                        logger.add_scalars('Loss/total_loss',
                                           {'validation': t_l}, current_step)
                        logger.add_scalars('Loss/mel_loss',
                                           {'validation': m_l}, current_step)
                        logger.add_scalars('Loss/mel_postnet_loss',
                                           {'validation': m_p_l}, current_step)
                        logger.add_scalars('Loss/duration_loss',
                                           {'validation': d_l}, current_step)
                        logger.add_scalars('Loss/F0_loss',
                                           {'validation': f_l}, current_step)
                        logger.add_scalars('Loss/energy_loss',
                                           {'validation': e_l}, current_step)
                    model.train()

                end_time = time.perf_counter()
                Time = np.append(Time, end_time - start_time)
                if len(Time) == hp.clear_Time:
                    temp_value = np.mean(Time)
                    Time = np.delete(Time,
                                     [i for i in range(len(Time))],
                                     axis=None)
                    Time = np.append(Time, temp_value)
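# Unlike the earlier variants, which use two SummaryWriter instances with
# add_scalar, this last one uses a single writer with add_scalars, so the
# 'training' and 'validation' curves share one chart per loss tag. A minimal
# sketch of that pattern (the log directory and values are illustrative):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('log')  # hypothetical log dir
for step in range(1, 4):
    writer.add_scalars('Loss/total_loss', {'training': 1.0 / step}, step)
    writer.add_scalars('Loss/total_loss', {'validation': 1.2 / step}, step)
writer.close()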