def collate_fn(batch):
    """Merge a list of dataset samples into a single padded batch.

    Each sample is indexed with the keys ``'text'`` and ``'mel'``. The texts
    are padded with ``pad_text`` (which also yields the padded positions), the
    alignment target is derived from the padded texts and positions with
    ``get_alignment``, and the mels are padded with ``pad_mel``.

    Args:
        batch: Sequence of samples, each supporting ``sample['text']`` and
            ``sample['mel']``.

    Returns:
        A dict with the keys ``"texts"``, ``"pos"``, ``"mels"`` and
        ``"alignment"``.
    """
    raw_texts = [sample['text'] for sample in batch]
    raw_mels = [sample['mel'] for sample in batch]

    padded_texts, padded_positions = pad_text(raw_texts)
    alignment = get_alignment(padded_texts, padded_positions)
    padded_mels = pad_mel(raw_mels)

    return dict(
        texts=padded_texts,
        pos=padded_positions,
        mels=padded_mels,
        alignment=alignment,
    )
def _append_lines(path, lines):
    """Append every string in ``lines`` to the text file ``path``, one per line."""
    with open(path, "a") as log_file:
        log_file.writelines(line + "\n" for line in lines)


def main(args):
    """Train FastSpeech, resuming from ``args.restore_step`` when possible.

    Builds the model, optimizer, loss and data loader, tries to restore the
    checkpoint named after ``args.restore_step`` from ``hp.checkpoint_path``
    and then runs ``hp.epochs`` epochs. Every step appends the losses to text
    files under ``hp.logger_path`` and writes a ``'loss'`` scalar to the
    summary writer; a checkpoint is saved every ``hp.save_step`` steps.

    Args:
        args: Parsed command-line namespace. Reads ``restore_step``,
            ``frozen_learning_rate`` and ``learning_rate_frozen``.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Define model
    model = nn.DataParallel(FastSpeech()).to(device)
    print("FastSpeech and Tacotron2 Have Been Defined")
    num_param = sum(param.numel() for param in model.parameters())
    print('Number of FastSpeech Parameters:', num_param)

    # Get dataset
    dataset = FastSpeechDataset()

    # Optimizer and loss
    optimizer = torch.optim.Adam(
        model.parameters(), betas=(0.9, 0.98), eps=1e-9)
    scheduled_optim = ScheduledOptim(optimizer,
                                     hp.word_vec_dim,
                                     hp.n_warm_up_step,
                                     args.restore_step)
    fastspeech_loss = FastSpeechLoss().to(device)
    print("Defined Optimizer and Loss Function.")

    # Get training loader
    print("Get Training Loader")
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=cpu_count())

    # makedirs(exist_ok=True) also copes with nested paths and with
    # directories that already exist.
    os.makedirs(hp.checkpoint_path, exist_ok=True)
    os.makedirs(hp.logger_path, exist_ok=True)

    # Restoring is best effort: any failure falls back to a fresh start, but
    # the reason is reported and KeyboardInterrupt/SystemExit still propagate.
    checkpoint_file = os.path.join(
        hp.checkpoint_path, 'checkpoint_%d.pth.tar' % args.restore_step)
    try:
        # map_location lets a checkpoint saved on GPU load on a CPU-only host.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
    except Exception as err:
        print("Checkpoint %s not restored (%s: %s)"
              % (checkpoint_file, type(err).__name__, err))
        print("\n------Start New Training------\n")
    else:
        print("\n------Model Restored at Step %d------\n" % args.restore_step)

    # Training
    model = model.train()

    steps_per_epoch = len(training_loader)
    total_step = hp.epochs * steps_per_epoch
    step_times = np.array([])
    # time.clock() was removed in Python 3.8; perf_counter() is the
    # replacement for measuring elapsed intervals.
    train_start = time.perf_counter()
    summary = SummaryWriter()

    for epoch in range(hp.epochs):
        for i, data_of_batch in enumerate(training_loader):
            start_time = time.perf_counter()

            current_step = i + args.restore_step + \
                epoch * steps_per_epoch + 1

            # Init
            scheduled_optim.zero_grad()

            # Prepare Data
            src_seq = torch.from_numpy(
                data_of_batch["texts"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["pos"]).long().to(device)
            mel_tgt = torch.from_numpy(
                data_of_batch["mels"]).float().to(device)

            if hp.pre_target:
                alignment_target = torch.from_numpy(
                    data_of_batch["alignment"]).float().to(device)
            else:
                # NOTE(review): `tacotron2` is not assigned anywhere in the
                # visible code (its construction is commented out), so this
                # branch presumably raises NameError unless it is defined at
                # module level elsewhere -- confirm before using
                # hp.pre_target = False.
                alignment_target = get_alignment(
                    src_seq, tacotron2).float().to(device)

            # For Data Parallel
            mel_max_len = mel_tgt.size(1)

            # Forward
            mel_output, mel_output_postnet, duration_predictor_output = model(
                src_seq, src_pos,
                mel_max_length=mel_max_len,
                length_target=alignment_target)

            # Cal Loss
            mel_loss, mel_postnet_loss, duration_predictor_loss = fastspeech_loss(
                mel_output, mel_output_postnet, duration_predictor_output,
                mel_tgt, alignment_target)
            total_loss = mel_loss + mel_postnet_loss + duration_predictor_loss

            # Logger: one value per step in each loss file. The files live in
            # hp.logger_path, the directory created above.
            t_l = total_loss.item()
            m_l = mel_loss.item()
            m_p_l = mel_postnet_loss.item()
            d_p_l = duration_predictor_loss.item()

            for file_name, value in (("total_loss.txt", t_l),
                                     ("mel_loss.txt", m_l),
                                     ("mel_postnet_loss.txt", m_p_l),
                                     ("duration_predictor_loss.txt", d_p_l)):
                _append_lines(os.path.join(hp.logger_path, file_name),
                              [str(value)])

            # Backward
            total_loss.backward()

            # Clipping gradients to avoid gradient explosion
            nn.utils.clip_grad_norm_(model.parameters(), hp.grad_clip_thresh)

            # Update weights
            if args.frozen_learning_rate:
                scheduled_optim.step_and_update_lr_frozen(
                    args.learning_rate_frozen)
            else:
                scheduled_optim.step_and_update_lr()

            # Print
            if current_step % hp.log_step == 0:
                now = time.perf_counter()
                # No timing sample exists yet if the very first step is a log
                # step; avoid the NaN (and RuntimeWarning) of an empty mean.
                mean_step_time = step_times.mean() if step_times.size else 0.0

                str1 = "Epoch [{}/{}], Step [{}/{}], Mel Loss: {:.4f}, Mel PostNet Loss: {:.4f};".format(
                    epoch+1, hp.epochs, current_step, total_step, m_l, m_p_l)
                str2 = "Duration Predictor Loss: {:.4f}, Total Loss: {:.4f}.".format(
                    d_p_l, t_l)
                str3 = "Current Learning Rate is {:.6f}.".format(
                    scheduled_optim.get_learning_rate())
                str4 = "Time Used: {:.3f}s, Estimated Time Remaining: {:.3f}s.".format(
                    (now-train_start), (total_step-current_step)*mean_step_time)

                print("\n" + str1)
                print(str2)
                print(str3)
                print(str4)

                _append_lines(os.path.join(hp.logger_path, "logger.txt"),
                              [str1, str2, str3, str4, ""])

            # NOTE(review): written on every step, like the per-step loss
            # files above -- confirm this is the intended cadence.
            summary.add_scalar('loss', t_l, current_step)

            if current_step % hp.save_step == 0:
                torch.save(
                    {'model': model.state_dict(),
                     'optimizer': optimizer.state_dict()},
                    os.path.join(hp.checkpoint_path,
                                 'checkpoint_%d.pth.tar' % current_step))
                print("save model at step %d ..." % current_step)

            end_time = time.perf_counter()
            step_times = np.append(step_times, end_time - start_time)
            if len(step_times) == hp.clear_Time:
                # Collapse the history into its mean so the array stays small.
                step_times = np.array([step_times.mean()])
def train(num_gpus, rank, group_name, output_directory, log_directory, checkpoint_path):
    """Train the WaveGlow model on ``FastSpeechDataset`` batches.

    Seeds torch with ``hp.seed``, optionally initialises distributed training,
    optionally resumes from ``checkpoint_path`` and then trains until
    ``hp.epochs`` is reached, printing the loss every iteration and saving a
    checkpoint every ``hp.save_step`` iterations on rank 0.

    Args:
        num_gpus: Number of GPUs; values above 1 enable ``init_distributed``,
            gradient all-reduce and loss reduction across processes.
        rank: Rank of this process. Only rank 0 creates the output directory,
            logs to the tensorboard logger and saves checkpoints.
        group_name: Forwarded to ``init_distributed``.
        output_directory: Directory that receives the ``TTSglow_<iteration>``
            checkpoints.
        log_directory: Forwarded to ``prepare_directories_and_logger``.
        checkpoint_path: Checkpoint to resume from; a falsy value starts from
            iteration 0.
    """
    # Get device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    torch.manual_seed(hp.seed)
    torch.cuda.manual_seed(hp.seed)

    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        init_distributed(rank, num_gpus, group_name, **dist_config)
    # =====END: ADDED FOR DISTRIBUTED======

    criterion = WaveGlowLoss(hp.sigma)
    # Put the model on the same device the batches are moved to below,
    # instead of unconditionally calling .cuda().
    model = WaveGlow().to(device)

    # =====START: ADDED FOR DISTRIBUTED======
    if num_gpus > 1:
        model = apply_gradient_allreduce(model)
    # =====END: ADDED FOR DISTRIBUTED======

    learning_rate = hp.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    if hp.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(model, optimizer, opt_level='O1')

    # Load checkpoint if one exists
    iteration = 0
    if checkpoint_path:
        model, optimizer, iteration = load_checkpoint(checkpoint_path, model,
                                                      optimizer)
        iteration += 1  # next iteration is iteration + 1

    # Get dataset
    dataset = FastSpeechDataset()

    # Get training loader
    print("Get Training Loader")
    training_loader = DataLoader(dataset,
                                 batch_size=hp.batch_size,
                                 shuffle=True,
                                 collate_fn=collate_fn,
                                 drop_last=True,
                                 num_workers=cpu_count())

    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        print("output directory", output_directory)

    if hp.with_tensorboard and rank == 0:
        logger = prepare_directories_and_logger(output_directory, log_directory)

    model = model.train()
    # Resume at the epoch the restored iteration counter falls into.
    epoch_offset = max(0, iteration // len(training_loader))
    print("Total Epochs: {}".format(hp.epochs))
    print("Batch Size: {}".format(hp.batch_size))

    # ================ MAIN TRAINING LOOP! ===================
    for epoch in range(epoch_offset, hp.epochs):
        print("Epoch: {}".format(epoch))
        for i, data_of_batch in enumerate(training_loader):
            model.zero_grad()

            # Prepare Data
            src_seq = torch.from_numpy(
                data_of_batch["texts"]).long().to(device)
            src_pos = torch.from_numpy(
                data_of_batch["pos"]).long().to(device)
            mel_tgt = torch.from_numpy(
                data_of_batch["mels"]).float().to(device)

            if hp.pre_target:
                alignment_target = torch.from_numpy(
                    data_of_batch["alignment"]).float().to(device)
            else:
                # NOTE(review): `tacotron2` is not assigned anywhere in the
                # visible code, so this branch presumably raises NameError
                # unless it is defined at module level elsewhere -- confirm
                # before using hp.pre_target = False.
                alignment_target = get_alignment(
                    src_seq, tacotron2).float().to(device)

            # For Data Parallel
            mel_max_len = mel_tgt.size(1)

            outputs = model(src_seq, src_pos, mel_tgt, mel_max_len,
                            alignment_target)
            mel_tgt = mel_tgt.transpose(1, 2)
            max_like, dur_loss = criterion(outputs, alignment_target, mel_tgt)
            loss = max_like + dur_loss

            if num_gpus > 1:
                reduced_loss = reduce_tensor(loss.data, num_gpus).item()
            else:
                reduced_loss = loss.item()

            if hp.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            optimizer.step()

            print("{}:\t{:.9f}".format(iteration, reduced_loss))
            if hp.with_tensorboard and rank == 0:
                logger.log_training(reduced_loss, dur_loss, learning_rate,
                                    iteration)

            if iteration % hp.save_step == 0 and rank == 0:
                # Separate local so the `checkpoint_path` argument is not
                # overwritten by the path of the checkpoint being written.
                save_path = "{}/TTSglow_{}".format(output_directory, iteration)
                save_checkpoint(model, optimizer, learning_rate, iteration,
                                save_path)

            iteration += 1