import argparse
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn.functional as F
import wandb
import yaml
from matplotlib import cm
from torch.nn import DataParallel as DP
from torch.utils.data import DataLoader

import mel2wav.modules
from mel2wav.modules import Audio2Mel, Discriminator, Generator

# AudioDataset, save_sample, parse_args, seed_everything and
# load_state_dict_handleDP are project-local helpers defined elsewhere
# in this repository.

#########################################################
# Variant 1: demiurge/melgan trainer (wandb + DataParallel) #
#########################################################


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seed_everything(7)
    args = parse_args()
    Path(args.save_path).mkdir(parents=True, exist_ok=True)

    entity = "demiurge"
    project = "melgan"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id
    restore_run_id = load_from_run_id or resume_run_id
    batch_size = args.batch_size

    # Get the initial run step and epoch. If restoring a run, replace the
    # current args with the restored run's config, keeping only the
    # requested batch size.
    steps = None
    if restore_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
        prev_args = argparse.Namespace(**previous_run.config)
        args = vars(args)
        args.update(vars(prev_args))
        args = argparse.Namespace(**args)
        args.batch_size = batch_size

    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate

    ratios = args.ratios
    if isinstance(ratios, str):
        ratios = ratios.replace(" ", "")
        ratios = ratios.strip("][").split(",")
        ratios = [int(i) for i in ratios]
        ratios = np.array(ratios)

    if load_from_run_id and resume_run_id:
        raise RuntimeError("Specify either --load_from_run_id or --resume_run_id.")

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(f"Starting new run with initial weights from run ID {load_from_run_id}.")
    else:
        print("Starting new run from scratch.")

    # Read one line of the train file list to log the dataset location.
    train_files = Path(args.data_path) / "train_files.txt"
    with open(train_files, encoding="utf-8", mode="r") as f:
        file = f.readline()
    args.train_file_sample = str(file)

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=bool(resume_run_id),
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )
    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(
        args.n_mel_channels, args.ngf, args.n_residual_layers, ratios=ratios
    ).to(device)
    netD = Discriminator(
        args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor
    ).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=args.learning_rate, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate, betas=(0.5, 0.9))

    if load_initial_weights:
        for model, filenames in [
            (netG, ["netG.pt", "netG_prev.pt"]),
            (optG, ["optG.pt", "optG_prev.pt"]),
            (netD, ["netD.pt", "netD_prev.pt"]),
            (optD, ["optD.pt", "optD_prev.pt"]),
        ]:
            recover_model = False
            filepath = None
            for filename in filenames:
                try:
                    run_path = f"{entity}/{project}/{restore_run_id}"
                    print(f"Restoring {filename} from run path {run_path}")
                    restored_file = wandb.restore(filename, run_path=run_path)
                    filepath = restored_file.name
                    model = load_state_dict_handleDP(model, filepath)
                    recover_model = True
                    break
                except RuntimeError as e:
                    print("RuntimeError", e)
                    print(f"recover model weight file: '{filename}' failed")
            if not recover_model:
                raise RuntimeError(
                    f"Cannot load model weight files for component {filenames[0]}."
                )
            else:
                # Back up the successfully recovered weight file ("*_prev.pt").
                path_parent = Path(filepath).parent
                newfilepath = str(path_parent / filenames[1])
                os.rename(filepath, newfilepath)
                wandb.save(newfilepath)

    if torch.cuda.device_count() > 1:
        netG = DP(netG).to(device)
        netD = DP(netD).to(device)
        fft = DP(fft).to(device)
        print(f"We have {torch.cuda.device_count()} gpus. Use data parallel.")
    else:
        print(f"We have {torch.cuda.device_count()} gpu.")

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")
    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    if not restore_run_id:
        steps = wandb.run.step
    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    melImages = []
    num_fix_samples = args.n_test_samples - (args.n_test_samples // 2)
    cmap = cm.get_cmap("inferno")
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio, caption=f"sample {i}", sample_rate=sampling_rate)
        )
        melImage = s_t.squeeze().detach().cpu().numpy()
        melImage = (melImage - np.amin(melImage)) / (
            np.amax(melImage) - np.amin(melImage)
        )
        melImages.append(wandb.Image(cmap(melImage), caption=f"sample {i}"))

        if i == num_fix_samples - 1:
            break

    wandb.log({"audio/original": samples}, step=start_epoch)
    wandb.log({"mel/original": melImages}, step=start_epoch)

    costs = []
    start = time.time()

    # Enable cudnn autotuner to speed up training.
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()
            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )

            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    melImages = []
                    # Fixed samples
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(
                            root / ("generated_%d.wav" % i), sampling_rate, pred_audio
                        )
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            )
                        )
                        melImage = voc.squeeze().detach().cpu().numpy()
                        melImage = (melImage - np.amin(melImage)) / (
                            np.amax(melImage) - np.amin(melImage)
                        )
                        melImages.append(
                            wandb.Image(cmap(melImage), caption=f"sample {i}")
                        )
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "mel/generated": melImages,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                    # Variable samples
                    source = []
                    pred = []
                    pred_mel = []
                    num_var_samples = args.n_test_samples - num_fix_samples
                    for i, x_t in enumerate(test_loader):
                        # Source
                        x_t = x_t.to(device)
                        audio = x_t.squeeze().cpu()
                        source.append(
                            wandb.Audio(
                                audio, caption=f"sample {i}", sample_rate=sampling_rate
                            )
                        )
                        # Prediction
                        s_t = fft(x_t).detach()
                        voc = s_t.to(device)
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        pred.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            )
                        )
                        melImage = voc.squeeze().detach().cpu().numpy()
                        melImage = (melImage - np.amin(melImage)) / (
                            np.amax(melImage) - np.amin(melImage)
                        )
                        pred_mel.append(
                            wandb.Image(cmap(melImage), caption=f"sample {i}")
                        )
                        # Stop once enough samples have been logged.
                        if i == num_var_samples - 1:
                            break
                    wandb.log(
                        {
                            "audio/var_original": source,
                            "audio/var_generated": pred,
                            "mel/var_generated": pred_mel,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))
                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print(
                    "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
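# A minimal self-contained sketch of the two adversarial losses computed in
# the loop above (multi-scale hinge loss plus feature matching), assuming,
# as netD does here, that each discriminator scale returns a list of
# intermediate feature maps with the final logits as the last element:

import torch.nn.functional as F


def hinge_d_loss(d_fake_det, d_real):
    """Hinge loss over all discriminator scales (fakes detached)."""
    loss = 0.0
    for scale in d_fake_det:
        loss = loss + F.relu(1 + scale[-1]).mean()
    for scale in d_real:
        loss = loss + F.relu(1 - scale[-1]).mean()
    return loss


def feature_matching_loss(d_fake, d_real, num_D, n_layers_D):
    """L1 distance between fake and real features, weighted as above."""
    wt = (4.0 / (n_layers_D + 1)) * (1.0 / num_D)
    loss = 0.0
    for fake_scale, real_scale in zip(d_fake, d_real):
        for f, r in zip(fake_scale[:-1], real_scale[:-1]):
            loss = loss + wt * F.l1_loss(f, r.detach())
    return loss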
#################################################################
# Variant 2: materialvision/melganmv trainer (wandb, simple restore) #
#################################################################

# Imports as in the previous variant (argparse, os, matplotlib and
# DataParallel are not needed here).


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    seed_everything(7)
    args = parse_args()
    Path(args.save_path).mkdir(parents=True, exist_ok=True)

    entity = "materialvision"
    project = "melganmv"
    load_from_run_id = args.load_from_run_id
    resume_run_id = args.resume_run_id
    restore_run_id = load_from_run_id or resume_run_id
    load_initial_weights = bool(restore_run_id)
    sampling_rate = args.sampling_rate

    if load_from_run_id and resume_run_id:
        raise RuntimeError("Specify either --load_from_run_id or --resume_run_id.")

    if resume_run_id:
        print(f"Resuming run ID {resume_run_id}.")
    elif load_from_run_id:
        print(f"Starting new run with initial weights from run ID {load_from_run_id}.")
    else:
        print("Starting new run from scratch.")

    wandb.init(
        entity=entity,
        project=project,
        id=resume_run_id,
        config=args,
        resume=bool(resume_run_id),
        save_code=True,
        dir=args.save_path,
        notes=args.notes,
    )
    print("run id: " + str(wandb.run.id))
    print("run name: " + str(wandb.run.name))

    root = Path(wandb.run.dir)
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    wandb.save("args.yml")

    ###############################################
    # The file modules.py is needed by the unagan #
    ###############################################
    wandb.save(mel2wav.modules.__file__, base_path=".")

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers).to(device)
    netD = Discriminator(
        args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor
    ).to(device)
    fft = Audio2Mel(
        n_mel_channels=args.n_mel_channels,
        pad_mode=args.pad_mode,
        sampling_rate=sampling_rate,
    ).to(device)

    for model in [netG, netD, fft]:
        wandb.watch(model)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=args.learning_rate, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=args.learning_rate, betas=(0.5, 0.9))

    if load_initial_weights:
        for obj, filename in [
            (netG, "netG.pt"),
            (optG, "optG.pt"),
            (netD, "netD.pt"),
            (optD, "optD.pt"),
        ]:
            run_path = f"{entity}/{project}/{restore_run_id}"
            print(f"Restoring {filename} from run path {run_path}")
            restored_file = wandb.restore(filename, run_path=run_path)
            obj.load_state_dict(torch.load(restored_file.name))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt",
        args.seq_len,
        sampling_rate=sampling_rate,
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        sampling_rate * 4,
        sampling_rate=sampling_rate,
        augment=False,
    )
    wandb.save(str(Path(args.data_path) / "train_files.txt"))
    wandb.save(str(Path(args.data_path) / "test_files.txt"))

    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    if len(train_loader) == 0:
        raise RuntimeError("Train dataset is empty.")
    if len(test_loader) == 0:
        raise RuntimeError("Test dataset is empty.")

    # Get the initial run step and epoch.
    if load_from_run_id:
        api = wandb.Api()
        previous_run = api.run(f"{entity}/{project}/{restore_run_id}")
        steps = previous_run.lastHistoryStep
    else:
        steps = wandb.run.step
    start_epoch = steps // len(train_loader)
    print(f"Starting with epoch {start_epoch} and step {steps}.")

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    samples = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.to(device)
        s_t = fft(x_t).detach()

        test_voc.append(s_t.to(device))
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), sampling_rate, audio)
        samples.append(
            wandb.Audio(audio, caption=f"sample {i}", sample_rate=sampling_rate)
        )

        if i == args.n_test_samples - 1:
            break

    if not resume_run_id:
        wandb.log({"audio/original": samples}, step=0)
    else:
        print("We are resuming, skipping logging of original audio.")

    costs = []
    start = time.time()

    # Enable cudnn autotuner to speed up training.
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    for epoch in range(start_epoch, start_epoch + args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.to(device)
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.to(device))

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.to(device).detach())
            D_real = netD(x_t.to(device))

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()
            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.to(device))

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

            wandb.log(
                {
                    "loss/discriminator": costs[-1][0],
                    "loss/generator": costs[-1][1],
                    "loss/feature_matching": costs[-1][2],
                    "loss/mel_reconstruction": costs[-1][3],
                },
                step=steps,
            )

            steps += 1

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    samples = []
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(
                            root / ("generated_%d.wav" % i), sampling_rate, pred_audio
                        )
                        samples.append(
                            wandb.Audio(
                                pred_audio,
                                caption=f"sample {i}",
                                sample_rate=sampling_rate,
                            )
                        )
                    wandb.log(
                        {
                            "audio/generated": samples,
                            "epoch": epoch,
                        },
                        step=steps,
                    )

                print("Saving models ...")
                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                wandb.save(str(root / "netG.pt"))
                wandb.save(str(root / "optG.pt"))
                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")
                wandb.save(str(root / "netD.pt"))
                wandb.save(str(root / "optD.pt"))

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")
                    wandb.save(str(root / "best_netD.pt"))
                    wandb.save(str(root / "best_netG.pt"))

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print(
                    "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
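# A minimal sketch of the W&B restore flow that both wandb variants above
# rely on; the run path and filename below are illustrative placeholders:

import torch
import wandb

run_path = "materialvision/melganmv/<run_id>"  # entity/project/run id

# wandb.restore downloads a file the referenced run saved earlier and
# returns a local file handle.
restored_file = wandb.restore("netG.pt", run_path=run_path)
state_dict = torch.load(restored_file.name, map_location="cpu")
# netG.load_state_dict(state_dict)

# lastHistoryStep gives the last logged step of that run, which the
# scripts above use to derive start_epoch.
api = wandb.Api()
steps = api.run(run_path).lastHistoryStep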
###########################################################
# Variant 3: TensorBoard trainer (hardcoded 22050 Hz)     #
###########################################################

# Imports as in the previous variants, except that logging goes through
# torch.utils.tensorboard.SummaryWriter instead of wandb.


def main():
    args = parse_args()

    root = Path(args.save_path)
    load_root = Path(args.load_path) if args.load_path else None
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(args.n_mel_channels, args.ngf, args.n_residual_layers).cuda()
    netD = Discriminator(
        args.num_D, args.ndf, args.n_layers_D, args.downsamp_factor
    ).cuda()
    fft = Audio2Mel(n_mel_channels=args.n_mel_channels).cuda()

    # print(netG)
    # print(netD)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(netG.parameters(), lr=1e-4, betas=(0.5, 0.9))
    optD = torch.optim.Adam(netD.parameters(), lr=1e-4, betas=(0.5, 0.9))

    if load_root and load_root.exists():
        # print("loading checkpoints")
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        netD.load_state_dict(torch.load(load_root / "netD.pt"))
        optD.load_state_dict(torch.load(load_root / "optD.pt"))

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args.data_path) / "train_files.txt", args.seq_len, sampling_rate=22050
    )
    test_set = AudioDataset(
        Path(args.data_path) / "test_files.txt",
        22050 * 4,
        sampling_rate=22050,
        augment=False,
    )

    train_loader = DataLoader(train_set, batch_size=args.batch_size, num_workers=4)
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(root / ("original_%d.wav" % i), 22050, audio)
        writer.add_audio("original/sample_%d.wav" % i, audio, 0, sample_rate=22050)

        if i == args.n_test_samples - 1:
            break

    costs = []
    start = time.time()

    # Enable cudnn autotuner to speed up training.
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0
    for epoch in range(1, args.epochs + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.cuda())

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            #######################
            # Train Discriminator #
            #######################
            D_fake_det = netD(x_pred_t.cuda().detach())
            D_real = netD(x_t.cuda())

            loss_D = 0
            for scale in D_fake_det:
                loss_D += F.relu(1 + scale[-1]).mean()
            for scale in D_real:
                loss_D += F.relu(1 - scale[-1]).mean()

            netD.zero_grad()
            loss_D.backward()
            optD.step()

            ###################
            # Train Generator #
            ###################
            D_fake = netD(x_pred_t.cuda())

            loss_G = 0
            for scale in D_fake:
                loss_G += -scale[-1].mean()

            loss_feat = 0
            feat_weights = 4.0 / (args.n_layers_D + 1)
            D_weights = 1.0 / args.num_D
            wt = D_weights * feat_weights
            for i in range(args.num_D):
                for j in range(len(D_fake[i]) - 1):
                    loss_feat += wt * F.l1_loss(D_fake[i][j], D_real[i][j].detach())

            netG.zero_grad()
            (loss_G + args.lambda_feat * loss_feat).backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append([loss_D.item(), loss_G.item(), loss_feat.item(), s_error])

            writer.add_scalar("loss/discriminator", costs[-1][0], steps)
            writer.add_scalar("loss/generator", costs[-1][1], steps)
            writer.add_scalar("loss/feature_matching", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            # print("parameter count:", netD.p_num)

            if steps % args.save_interval == 0:
                st = time.time()
                with torch.no_grad():
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(
                            root / ("generated_{}_{}.wav".format(epoch, i)),
                            22050,
                            pred_audio,
                        )
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            epoch,
                            sample_rate=22050,
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")
                torch.save(netD.state_dict(), root / "netD.pt")
                torch.save(optD.state_dict(), root / "optD.pt")

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netD.state_dict(), root / "best_netD.pt")
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args.log_interval == 0:
                print(
                    "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start) / args.log_interval,
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
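# Audio2Mel's implementation lives in mel2wav/modules.py. As a rough
# reference, a log-mel front-end playing the same role can be sketched
# with torchaudio; the parameter defaults below are assumptions for
# illustration, not the repository's actual values:

import torch
import torchaudio


class LogMelSketch(torch.nn.Module):
    """Waveform (batch, 1, samples) -> log-mel spectrogram."""

    def __init__(self, sampling_rate=22050, n_fft=1024, hop_length=256,
                 win_length=1024, n_mel_channels=80):
        super().__init__()
        self.mel = torchaudio.transforms.MelSpectrogram(
            sample_rate=sampling_rate,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mels=n_mel_channels,
        )

    def forward(self, audio):
        mel = self.mel(audio.squeeze(1))
        # Clamp before the log to avoid -inf on silent frames.
        return torch.log10(torch.clamp(mel, min=1e-5))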
###############################################################
# Variant 4: generator-only trainer with multi-res STFT loss  #
###############################################################

# Imports as in the previous variants, plus the MultiResolutionSTFTLoss
# used below. The config is a nested dict (e.g. loaded from YAML) rather
# than an argparse namespace.


def trainG(args):
    root = Path(args["logging"]["save_path"])
    load_root = (
        Path(args["logging"]["load_path"]) if args["logging"]["load_path"] else None
    )
    root.mkdir(parents=True, exist_ok=True)

    ####################################
    # Dump arguments and create logger #
    ####################################
    with open(root / "args.yml", "w") as f:
        yaml.dump(args, f)
    writer = SummaryWriter(str(root))

    #######################
    # Load PyTorch Models #
    #######################
    netG = Generator(
        args["fft"]["n_mel_channels"],
        args["Generator"]["ngf"],
        args["Generator"]["n_residual_layers"],
        ratios=args["Generator"]["ratios"],
    ).cuda()
    if args["Generator"].get("G_path") is not None:
        netG.load_state_dict(
            torch.load(Path(args["Generator"]["G_path"]) / "netG.pt")
        )

    fft = Audio2Mel(
        n_mel_channels=args["fft"]["n_mel_channels"],
        n_fft=args["fft"]["n_fft"],
        hop_length=args["fft"]["hop_length"],
        win_length=args["fft"]["win_length"],
        sampling_rate=args["data"]["sampling_rate"],
        mel_fmin=args["fft"]["mel_fmin"],
        mel_fmax=args["fft"]["mel_fmax"],
    ).cuda()
    print(netG)

    #####################
    # Create optimizers #
    #####################
    optG = torch.optim.Adam(
        netG.parameters(),
        lr=args["optimizer"]["lrG"],
        betas=args["optimizer"]["betasG"],
    )

    if load_root and load_root.exists():
        netG.load_state_dict(torch.load(load_root / "netG.pt"))
        optG.load_state_dict(torch.load(load_root / "optG.pt"))
        print("checkpoints loaded")

    #######################
    # Create data loaders #
    #######################
    train_set = AudioDataset(
        Path(args["data"]["data_path"]) / "train_files_inv.txt",
        segment_length=args["data"]["seq_len"],
        sampling_rate=args["data"]["sampling_rate"],
        augment=["amp", "flip", "neg"],
    )
    test_set = AudioDataset(
        Path(args["data"]["data_path"]) / "test_files_inv.txt",
        segment_length=args["data"]["sampling_rate"] * 4,
        sampling_rate=args["data"]["sampling_rate"],
        augment=None,
    )

    train_loader = DataLoader(
        train_set,
        batch_size=args["dataloader"]["batch_size"],
        num_workers=4,
        pin_memory=True,
        shuffle=True,
    )
    test_loader = DataLoader(test_set, batch_size=1)

    ##########################
    # Dumping original audio #
    ##########################
    test_voc = []
    test_audio = []
    for i, x_t in enumerate(test_loader):
        x_t = x_t.cuda()
        s_t = fft(x_t).detach()

        test_voc.append(s_t.cuda())
        test_audio.append(x_t)

        audio = x_t.squeeze().cpu()
        save_sample(
            root / ("original_%d.wav" % i), args["data"]["sampling_rate"], audio
        )
        writer.add_audio(
            "original/sample_%d.wav" % i,
            audio,
            0,
            sample_rate=args["data"]["sampling_rate"],
        )

        if i == args["logging"]["n_test_samples"] - 1:
            break

    costs = []
    start = time.time()

    # Enable cudnn autotuner to speed up training.
    torch.backends.cudnn.benchmark = True

    best_mel_reconst = 1000000
    steps = 0

    mr_stft_loss = MultiResolutionSTFTLoss().cuda()

    for epoch in range(1, args["train"]["epochs"] + 1):
        for iterno, x_t in enumerate(train_loader):
            x_t = x_t.cuda()
            s_t = fft(x_t).detach()
            x_pred_t = netG(s_t.cuda())

            with torch.no_grad():
                s_pred_t = fft(x_pred_t.detach())
                s_error = F.l1_loss(s_t, s_pred_t).item()

            ###################
            # Train Generator #
            ###################
            sc, sm = mr_stft_loss(x_pred_t, x_t)
            loss_G = (
                args["losses"]["lambda_sc"] * sc + args["losses"]["lambda_sm"] * sm
            )

            netG.zero_grad()
            loss_G.backward()
            optG.step()

            ######################
            # Update tensorboard #
            ######################
            costs.append([loss_G.item(), sc.item(), sm.item(), s_error])

            writer.add_scalar("loss/generator", costs[-1][0], steps)
            writer.add_scalar("loss/convergence", costs[-1][1], steps)
            writer.add_scalar("loss/logmag", costs[-1][2], steps)
            writer.add_scalar("loss/mel_reconstruction", costs[-1][3], steps)
            steps += 1

            if steps % args["logging"]["save_interval"] == 0:
                st = time.time()
                with torch.no_grad():
                    for i, (voc, _) in enumerate(zip(test_voc, test_audio)):
                        pred_audio = netG(voc)
                        pred_audio = pred_audio.squeeze().cpu()
                        save_sample(
                            root / ("generated_%d.wav" % i),
                            args["data"]["sampling_rate"],
                            pred_audio,
                        )
                        writer.add_audio(
                            "generated/sample_%d.wav" % i,
                            pred_audio,
                            epoch,
                            sample_rate=args["data"]["sampling_rate"],
                        )

                torch.save(netG.state_dict(), root / "netG.pt")
                torch.save(optG.state_dict(), root / "optG.pt")

                if np.asarray(costs).mean(0)[-1] < best_mel_reconst:
                    best_mel_reconst = np.asarray(costs).mean(0)[-1]
                    torch.save(netG.state_dict(), root / "best_netG.pt")

                print("Took %5.4fs to generate samples" % (time.time() - st))
                print("-" * 100)

            if steps % args["logging"]["log_interval"] == 0:
                print(
                    "Epoch {} | Iters {} / {} | ms/batch {:5.2f} | loss {}".format(
                        epoch,
                        iterno,
                        len(train_loader),
                        1000 * (time.time() - start)
                        / args["logging"]["log_interval"],
                        np.asarray(costs).mean(0),
                    )
                )
                costs = []
                start = time.time()
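# MultiResolutionSTFTLoss is imported from elsewhere in the project. A
# minimal sketch of the usual formulation (spectral convergence plus
# log-magnitude L1, averaged over several STFT resolutions) follows; the
# resolution triples are common defaults and an assumption here, not
# necessarily what this repository ships:

import torch
import torch.nn.functional as F


def _stft_mag(x, n_fft, hop_length, win_length):
    """Magnitude spectrogram of a (batch, samples) waveform."""
    window = torch.hann_window(win_length, device=x.device)
    spec = torch.stft(
        x, n_fft, hop_length, win_length, window=window, return_complex=True
    )
    return spec.abs().clamp(min=1e-7)


class MultiResolutionSTFTLossSketch(torch.nn.Module):
    def __init__(self, resolutions=((1024, 120, 600),
                                    (2048, 240, 1200),
                                    (512, 50, 240))):
        super().__init__()
        self.resolutions = resolutions  # (n_fft, hop_length, win_length)

    def forward(self, x_pred, x_true):
        x_pred = x_pred.squeeze(1)  # (batch, 1, samples) -> (batch, samples)
        x_true = x_true.squeeze(1)
        sc_loss, mag_loss = 0.0, 0.0
        for n_fft, hop, win in self.resolutions:
            m_pred = _stft_mag(x_pred, n_fft, hop, win)
            m_true = _stft_mag(x_true, n_fft, hop, win)
            # Spectral convergence: relative Frobenius error of magnitudes.
            sc_loss = sc_loss + torch.norm(m_true - m_pred) / torch.norm(m_true)
            # Log-magnitude L1.
            mag_loss = mag_loss + F.l1_loss(torch.log(m_pred), torch.log(m_true))
        n = len(self.resolutions)
        return sc_loss / n, mag_loss / n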