def train():
    """Train the SSRN model, validating after every epoch.

    Each epoch runs one optimisation pass over the training set, prints the
    summed and per-sample-averaged loss, writes a checkpoint every
    ``save_every`` epochs, runs a validation pass, and finally dumps the
    model's prediction and the matching ground truth for the last validation
    batch under ``save_stuff/mel_pred``.

    All settings (data directories, batch size, epoch count, learning rate,
    gradient-clipping norm) are hard-coded below. The function takes no
    arguments and returns nothing; its results are the files it writes and
    the progress it prints.
    """
    import os  # local import: only needed here to create the output directories

    print("training begins...")

    # training and validation relative directories
    train_directory = "../../ETTT/Pytorch-DCTTS/LJSpeech_data/"
    val_directory = "../../ETTT/Pytorch-DCTTS/LJSpeech_val/"
    t_data = LJDataset(train_directory)
    v_data = LJDataset(val_directory)
    train_len = len(t_data.bases)
    val_len = len(v_data.bases)

    # training parameters
    batch_size = 40
    epochs = 500
    save_every = 5
    learning_rate = 1e-4
    max_grad_norm = 1.0

    # Output locations. Create them up front so a missing directory is caught
    # immediately instead of after the first full epoch of training.
    checkpoint_dir = "save_stuff/checkpoint"
    mel_pred_dir = "save_stuff/mel_pred"
    os.makedirs(checkpoint_dir, exist_ok=True)
    os.makedirs(mel_pred_dir, exist_ok=True)

    # create model and optim
    hp = Hparams()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SSRN(hp, device)
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # The loaders are loop-invariant: a shuffling DataLoader draws a fresh
    # permutation every time it is iterated, so one instance serves all epochs.
    t_loader = DataLoader(t_data,
                          batch_size=batch_size,
                          shuffle=True,
                          drop_last=False,
                          collate_fn=ssrn_collate_fn)
    v_loader = DataLoader(v_data,
                          batch_size=batch_size // 10,
                          shuffle=True,
                          drop_last=False,
                          collate_fn=ssrn_collate_fn)

    # main training loop
    for ep in tqdm(range(epochs)):
        # Validation below switches the model to eval mode, so re-enable
        # training behaviour (dropout, batch-norm updates) every epoch.
        model.train()
        total_loss = 0.0  # epoch loss
        for batch in tqdm(t_loader):
            batch_loss = model.compute_batch_loss(batch)

            # batch update
            optim.zero_grad()
            batch_loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters=model.parameters(),
                                           max_norm=max_grad_norm)
            optim.step()
            total_loss += batch_loss.item()

        # one epoch complete, report the loss
        # NOTE(review): the average divides the sum of per-batch losses by the
        # number of samples; whether that is a true per-sample mean depends on
        # how compute_batch_loss reduces over the batch - confirm.
        print(
            "epoch {}, total loss:{}, average total loss:{}, validating now..."
            .format(ep, total_loss, total_loss / train_len))

        # if time to save, we save model
        if ep % save_every == 0:
            torch.save(
                model.state_dict(),
                "{}/epoch_{}_ssrn_model.pt".format(checkpoint_dir, ep))

        # Validation phase: eval mode so the reported loss and the dumped
        # prediction are not perturbed by training-only layer behaviour.
        model.eval()
        with torch.no_grad():
            val_loss = 0.0
            # Tracked under its own name so an empty validation loader can
            # never fall through to a left-over training batch.
            last_batch = None
            for val_batch in tqdm(v_loader):
                val_loss += model.compute_batch_loss(val_batch).item()
                last_batch = val_batch

            print("validation loss:{}, average validation loss:{}".format(
                val_loss, val_loss / val_len))

            if last_batch is not None:
                # NOTE(review): every sample writes to the same two files, so
                # only the last sample of the last batch survives on disk.
                # Kept as-is to preserve the existing file names.
                for x, y in last_batch:
                    # assumes x holds 80 feature channels per frame, as implied
                    # by the view(1, 80, -1) below - TODO confirm
                    predict, _ = model.forward((x.view(1, 80, -1)).to(device))
                    np.save(
                        "{}/epoch_{}_mel_pred.npy".format(mel_pred_dir, ep),
                        predict.detach().cpu().numpy())
                    np.save(
                        "{}/epoch_{}_ground_truth.npy".format(mel_pred_dir, ep),
                        y)
else: print( "Will fall back to saved config. If you want to use the current config file, run with flag " "'-cc'\n") Config.set_config(conf) # Tensorboard writer = SummaryWriter(args.log_dir) print("Loading SSRN...") net = SSRN().to(device) net.apply(weight_init) l1_criterion = nn.L1Loss().to(device) bd_criterion = nn.BCEWithLogitsLoss().to(device) optimizer = torch.optim.Adam(net.parameters(), lr=0.001) global_step = 0 # Learning rate decay. Noam scheme warmup_steps = 4000.0 def decay(_): step = global_step + 1 return warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5) scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=decay) if args.restore_path is not None: print("Restoring from checkpoint: {}".format(args.restore_path)) state = torch.load(args.restore_path, map_location=device) global_step = state["global_step"]