import json
import os

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-local helpers (Dataset, get_model, get_vocoder, get_param_num, to_device,
# log, synth_one_sample, ParallelTacotron2Loss / FastSpeech2Loss) are imported from
# the repository's own modules and are not shown here.


def main(args, configs):
    print("Prepare training ...")

    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset("train.txt", preprocess_config, train_config, sort=True, drop_last=True)
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    with open(
        os.path.join(preprocess_config["path"]["preprocessed_path"], "stats.json")
    ) as f:
        stats = json.load(f)
        mel_stats = stats["mel"]

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True)
    model = nn.DataParallel(model)
    num_param = get_param_num(model)
    Loss = ParallelTacotron2Loss(model_config, train_config).to(device)
    print("Number of Parallel Tacotron 2 Parameters:", num_param)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()

    # with torch.autograd.detect_anomaly():
    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device, mel_stats)

                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output, step)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}, ".format(step, total_step)
                    message2 = "Total Loss: {:.4f}, Mel Loss: {:.4f}, Duration Loss: {:.4f}, KL Loss: {:.4f}".format(
                        *losses)

                    with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                        f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

                if step % synth_step == 0:
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output,
                        vocoder,
                        model_config,
                        preprocess_config,
                        mel_stats,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder, len(losses), mel_stats)
                    with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                        f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1
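
# The training loop above calls optimizer.step_and_update_lr() / optimizer.zero_grad()
# and checkpoints optimizer._optimizer.state_dict(), so `optimizer` is expected to be a
# scheduler wrapper around a plain torch optimizer rather than torch.optim.Adam itself.
# The class below is only a minimal, hypothetical sketch of such a wrapper; the attribute
# names besides `_optimizer` and the Noam-style warm-up schedule are assumptions, not the
# repository's exact implementation.
import numpy as np


class ScheduledOptim:
    """Wraps a torch optimizer and rescales its learning rate on every update."""

    def __init__(self, model, init_lr=1.0, warmup_steps=4000, current_step=0):
        self._optimizer = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
        self.init_lr = init_lr
        self.warmup_steps = warmup_steps
        self.current_step = current_step

    def _get_lr_scale(self):
        # Transformer-style schedule: linear warm-up, then inverse-sqrt decay.
        return np.min([
            np.power(self.current_step, -0.5),
            np.power(self.warmup_steps, -1.5) * self.current_step,
        ])

    def step_and_update_lr(self):
        self.current_step += 1
        lr = self.init_lr * self._get_lr_scale()
        for param_group in self._optimizer.param_groups:
            param_group["lr"] = lr
        self._optimizer.step()

    def zero_grad(self):
        self._optimizer.zero_grad()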
def evaluate(model, step, configs, logger=None, vocoder=None, num_losses=6, mel_stats=None):
    # `num_losses` and `mel_stats` are accepted here so the signature matches how the
    # training loop above calls this function.
    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset(
        "val.txt", preprocess_config, train_config, sort=False, drop_last=False
    )
    batch_size = train_config["optimizer"]["batch_size"]
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=dataset.collate_fn,
    )

    # Get loss function
    Loss = ParallelTacotron2Loss(model_config, train_config).to(device)

    # Evaluation
    loss_sums = [0 for _ in range(num_losses)]
    for batchs in loader:
        for batch in batchs:
            batch = to_device(batch, device, mel_stats)
            with torch.no_grad():
                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output, step)

                for i in range(len(losses)):
                    loss_sums[i] += losses[i].item() * len(batch[0])

    loss_means = [loss_sum / len(dataset) for loss_sum in loss_sums]

    message = "Validation Step {}, Total Loss: {:.4f}, Mel Loss: {:.4f}, Duration Loss: {:.4f}, KL Loss: {:.4f}, Attention Loss: {:.4f}".format(
        *([step] + [l for l in loss_means])
    )

    if logger is not None:
        fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
            batch,
            output,
            vocoder,
            model_config,
            preprocess_config,
            mel_stats,
        )

        log(logger, step, losses=loss_means)
        log(
            logger,
            fig=fig,
            tag="Validation/step_{}_{}".format(step, tag),
        )
        sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
        log(
            logger,
            audio=wav_reconstruction,
            sampling_rate=sampling_rate,
            tag="Validation/step_{}_{}_reconstructed".format(step, tag),
        )
        log(
            logger,
            audio=wav_prediction,
            sampling_rate=sampling_rate,
            tag="Validation/step_{}_{}_synthesized".format(step, tag),
        )

    return message
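
# `log(...)` used above is a project helper that is not defined in this file. The sketch
# below is a hypothetical TensorBoard-backed version that accepts the keyword arguments the
# training and validation loops pass; the scalar tag names and loss ordering are assumptions.
def log(logger, step=None, losses=None, fig=None, audio=None, sampling_rate=22050, tag=""):
    if losses is not None:
        # Losses arrive as a flat list; write each one as its own scalar curve.
        for name, value in zip(
            ["total_loss", "mel_loss", "duration_loss", "kl_loss", "attn_loss"], losses
        ):
            logger.add_scalar("Loss/{}".format(name), value, step)
    if fig is not None:
        logger.add_figure(tag, fig)
    if audio is not None:
        # SummaryWriter expects float audio in [-1, 1]; normalize by the peak to be safe.
        logger.add_audio(tag, audio / max(abs(audio).max(), 1e-9), sample_rate=sampling_rate)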
def main(args, configs):
    print("Prepare training ...")

    preprocess_config, model_config, train_config = configs

    # Get dataset
    dataset = Dataset("train.txt", preprocess_config, train_config, model_config, sort=True, drop_last=True)
    batch_size = train_config["optimizer"]["batch_size"]
    group_size = 4  # Set this larger than 1 to enable sorting in Dataset
    assert batch_size * group_size < len(dataset)
    loader = DataLoader(
        dataset,
        batch_size=batch_size * group_size,
        shuffle=True,
        collate_fn=dataset.collate_fn,
    )

    # Prepare model
    model, optimizer = get_model(args, configs, device, train=True)
    model = nn.DataParallel(model)
    num_param = get_param_num(model)
    Loss = FastSpeech2Loss(preprocess_config, model_config).to(device)
    print("Number of FastSpeech2 Parameters:", num_param)

    # Load checkpoint if exists
    if args.restore_path is not None and os.path.isfile(args.restore_path):
        checkpoint = torch.load(args.restore_path)
        pretrained_dict = checkpoint['model']
        if not any(key.startswith('module.') for key in pretrained_dict):
            # The checkpoint was saved from model.module; re-add the DataParallel prefix.
            pretrained_dict = {
                'module.' + k: v
                for k, v in pretrained_dict.items()
            }

        dem1 = 0  # weights that match in both name and shape
        dem2 = 0  # weights that are skipped
        model_dict = model.state_dict()
        for k, v in pretrained_dict.items():
            if k in model_dict and v.size() == model_dict[k].size():
                # print('Load weight in ', k)
                dem1 += 1
            else:
                print(f'Module {k} does not have the same size')
                dem2 += 1
        dem2 += dem1
        print(f'### Load {dem1}/{dem2} modules')

        # 1. filter out unnecessary keys
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items()
            if k in model_dict and v.size() == model_dict[k].size()
        }
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        # model.load_state_dict(checkpoint['model'])
        # optimizer.load_state_dict(checkpoint['optimizer'])
        print("\n---Model Restored at Step {}---\n".format(args.restore_step))

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Init logger
    for p in train_config["path"].values():
        os.makedirs(p, exist_ok=True)
    train_log_path = os.path.join(train_config["path"]["log_path"], "train")
    val_log_path = os.path.join(train_config["path"]["log_path"], "val")
    os.makedirs(train_log_path, exist_ok=True)
    os.makedirs(val_log_path, exist_ok=True)
    train_logger = SummaryWriter(train_log_path)
    val_logger = SummaryWriter(val_log_path)

    # Training
    step = args.restore_step + 1
    epoch = 1
    grad_acc_step = train_config["optimizer"]["grad_acc_step"]
    grad_clip_thresh = train_config["optimizer"]["grad_clip_thresh"]
    total_step = train_config["step"]["total_step"]
    log_step = train_config["step"]["log_step"]
    save_step = train_config["step"]["save_step"]
    synth_step = train_config["step"]["synth_step"]
    val_step = train_config["step"]["val_step"]

    outer_bar = tqdm(total=total_step, desc="Training", position=0)
    outer_bar.n = args.restore_step
    outer_bar.update()

    while True:
        inner_bar = tqdm(total=len(loader), desc="Epoch {}".format(epoch), position=1)
        for batchs in loader:
            for batch in batchs:
                batch = to_device(batch, device)

                # Forward
                output = model(*(batch[2:]))

                # Cal Loss
                losses = Loss(batch, output)
                total_loss = losses[0]

                # Backward
                total_loss = total_loss / grad_acc_step
                total_loss.backward()
                if step % grad_acc_step == 0:
                    # Clipping gradients to avoid gradient explosion
                    nn.utils.clip_grad_norm_(model.parameters(), grad_clip_thresh)

                    # Update weights
                    optimizer.step_and_update_lr()
                    optimizer.zero_grad()

                if step % log_step == 0:
                    losses = [l.item() for l in losses]
                    message1 = "Step {}/{}|".format(step, total_step)
                    message2 = "|Total Loss: {:.4f}|Mel Loss: {:.4f}|Mel PostNet Loss: {:.4f}|Pitch Loss: {:.4f}|Energy Loss: {:.4f}|Duration Loss: {:.4f}|".format(
                        *losses)

                    # with open(os.path.join(train_log_path, "log.txt"), "a") as f:
                    #     f.write(message1 + message2 + "\n")

                    outer_bar.write(message1 + message2)

                    log(train_logger, step, losses=losses)

                if step % synth_step == 0:
                    output_prediction = model(*(batch[2:6]))
                    fig, wav_reconstruction, wav_prediction, tag = synth_one_sample(
                        batch,
                        output_prediction,
                        vocoder,
                        model_config,
                        preprocess_config,
                    )
                    log(
                        train_logger,
                        fig=fig,
                        tag="Training/step_{}_{}".format(step, tag),
                    )
                    sampling_rate = preprocess_config["preprocessing"]["audio"]["sampling_rate"]
                    log(
                        train_logger,
                        audio=wav_reconstruction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_reconstructed".format(step, tag),
                    )
                    log(
                        train_logger,
                        audio=wav_prediction,
                        sampling_rate=sampling_rate,
                        tag="Training/step_{}_{}_synthesized".format(step, tag),
                    )

                if step % val_step == 0:
                    model.eval()
                    message = evaluate(model, step, configs, val_logger, vocoder)
                    # with open(os.path.join(val_log_path, "log.txt"), "a") as f:
                    #     f.write(message + "\n")
                    outer_bar.write(message)

                    model.train()

                if step % save_step == 0:
                    torch.save(
                        {
                            "model": model.module.state_dict(),
                            "optimizer": optimizer._optimizer.state_dict(),
                        },
                        os.path.join(
                            train_config["path"]["ckpt_path"],
                            "{}.pth.tar".format(step),
                        ),
                    )

                if step == total_step:
                    quit()
                step += 1
                outer_bar.update(1)

            inner_bar.update(1)
        epoch += 1
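
# `args`, `configs`, and the global `device` are never defined in the functions above; in
# FastSpeech2-style repositories they are usually created in a launcher block. The block
# below is a hypothetical sketch of such a launcher: the --restore_step/--restore_path flags
# and the three YAML config paths are assumptions chosen to match how the functions use them.
if __name__ == "__main__":
    import argparse
    import yaml

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, default=0)
    parser.add_argument("--restore_path", type=str, default=None)
    parser.add_argument("-p", "--preprocess_config", type=str, required=True,
                        help="path to preprocess.yaml")
    parser.add_argument("-m", "--model_config", type=str, required=True,
                        help="path to model.yaml")
    parser.add_argument("-t", "--train_config", type=str, required=True,
                        help="path to train.yaml")
    args = parser.parse_args()

    # Read the three config files and hand them to the training entry point.
    preprocess_config = yaml.load(open(args.preprocess_config, "r"), Loader=yaml.FullLoader)
    model_config = yaml.load(open(args.model_config, "r"), Loader=yaml.FullLoader)
    train_config = yaml.load(open(args.train_config, "r"), Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    main(args, configs)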