import datetime
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Project-local modules (defined elsewhere in the repository):
# IntraSpeakerDataset, collate_batch, FragmentVC, RAdam,
# get_cosine_schedule_with_warmup, model_fn, valid


def main(
    data_dir,
    save_dir,
    total_steps,
    warmup_steps,
    valid_steps,
    log_steps,
    save_steps,
    milestones,
    exclusive_rate,
    n_samples,
    accu_steps,
    batch_size,
    n_workers,
    preload,
    comment,
    ckpt,
    grad_norm_clip,
    use_target_features,
    **kwargs,
):
    """Main function."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    metadata_path = Path(data_dir) / "metadata.json"

    dataset = IntraSpeakerDataset(
        data_dir, metadata_path, n_samples, preload, ref_feat=use_target_features
    )

    # Hold out 10% of the data for validation.
    trainlen = int(0.9 * len(dataset))
    lengths = [trainlen, len(dataset) - trainlen]
    trainset, validset = random_split(dataset, lengths)

    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=n_workers,
        pin_memory=True,
        collate_fn=collate_batch,
    )
    valid_loader = DataLoader(
        validset,
        batch_size=batch_size * accu_steps,
        num_workers=n_workers,
        drop_last=True,
        pin_memory=True,
        collate_fn=collate_batch,
    )
    train_iterator = iter(train_loader)

    if comment is not None:
        log_dir = "logs/"
        log_dir += datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        log_dir += "_" + comment
        writer = SummaryWriter(log_dir)

    save_dir_path = Path(save_dir)
    save_dir_path.mkdir(parents=True, exist_ok=True)

    if ckpt is not None:
        # Resume from a checkpoint named like "retriever-step10000-loss0dot1234.pt".
        try:
            start_step = int(ckpt.split("-")[1][4:])
            ref_included = True
        except (IndexError, ValueError):
            start_step = 0
            ref_included = False
        model = torch.jit.load(ckpt).to(device)
        optimizer = RAdam(
            [
                {"params": model.unet.parameters(), "lr": 1e-6},
                {"params": model.smoothers.parameters()},
                {"params": model.mel_linear.parameters()},
                {"params": model.post_net.parameters()},
            ],
            lr=1e-4,
        )
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, warmup_steps, total_steps - start_step
        )
        print("Optimizer and scheduler restarted.")
        print(f"Model loaded from {ckpt}, iteration: {start_step}")
    else:
        ref_included = False
        start_step = 0
        model = FragmentVC().to(device)
        model = torch.jit.script(model)
        optimizer = RAdam(model.parameters(), lr=1e-4)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps)

    criterion = nn.L1Loss()

    best_loss = float("inf")
    best_state_dict = None
    self_exclude = 0.0

    pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

    for step in range(start_step, total_steps):
        batch_loss = 0.0

        # Accumulate gradients over `accu_steps` batches before each update.
        for _ in range(accu_steps):
            try:
                batch = next(train_iterator)
            except StopIteration:
                train_iterator = iter(train_loader)
                batch = next(train_iterator)

            loss = model_fn(batch, model, criterion, self_exclude, ref_included, device)
            loss = loss / accu_steps
            batch_loss += loss.item()
            loss.backward()

        # Clip gradients before the optimizer consumes them.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        pbar.update()
        pbar.set_postfix(loss=f"{batch_loss:.2f}", excl=self_exclude, step=step + 1)

        if step % log_steps == 0 and comment is not None:
            writer.add_scalar("Loss/train", batch_loss, step)
            writer.add_scalar("Self-exclusive Rate", self_exclude, step)

        if (step + 1) % valid_steps == 0:
            pbar.close()

            valid_loss = valid(valid_loader, model, criterion, device)

            if comment is not None:
                writer.add_scalar("Loss/valid", valid_loss, step + 1)

            if valid_loss < best_loss:
                best_loss = valid_loss
                best_state_dict = model.state_dict()

            pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

        if (step + 1) % save_steps == 0 and best_state_dict is not None:
            loss_str = f"{best_loss:.4f}".replace(".", "dot")
            best_ckpt_name = f"retriever-best-loss{loss_str}.pt"

            loss_str = f"{valid_loss:.4f}".replace(".", "dot")
            curr_ckpt_name = f"retriever-step{step+1}-loss{loss_str}.pt"

            # Save both the best and the current model, then restore the
            # current weights and move the model back to the device.
            current_state_dict = model.state_dict()
            model.cpu()

            model.load_state_dict(best_state_dict)
            model.save(str(save_dir_path / best_ckpt_name))

            model.load_state_dict(current_state_dict)
            model.save(str(save_dir_path / curr_ckpt_name))

            model.to(device)
            pbar.write(f"Step {step + 1}, best model saved. (loss={best_loss:.4f})")

        if (step + 1) >= milestones[1]:
            self_exclude = exclusive_rate

        elif (step + 1) == milestones[0]:
            # Reference (target-speaker) features are used from this step on,
            # so restart the optimizer and scheduler with per-module learning rates.
            ref_included = True
            optimizer = RAdam(
                [
                    {"params": model.unet.parameters(), "lr": 1e-6},
                    {"params": model.smoothers.parameters()},
                    {"params": model.mel_linear.parameters()},
                    {"params": model.post_net.parameters()},
                ],
                lr=1e-4,
            )
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, warmup_steps, total_steps - milestones[0]
            )
            pbar.write("Optimizer and scheduler restarted.")

        elif (step + 1) > milestones[0]:
            # Linearly ramp the self-exclusion rate between the two milestones.
            self_exclude = (step + 1 - milestones[0]) / (milestones[1] - milestones[0])
            self_exclude *= exclusive_rate

    pbar.close()
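# --- Hypothetical CLI entry point (not part of the original source): a minimal
# sketch of how `main` above might be invoked, assuming its keyword arguments
# map one-to-one onto command-line flags. All default values below are
# illustrative assumptions, not the authors' settings.
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("data_dir", type=str)
    parser.add_argument("save_dir", type=str)
    parser.add_argument("--total_steps", type=int, default=250000)
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--valid_steps", type=int, default=1000)
    parser.add_argument("--log_steps", type=int, default=100)
    parser.add_argument("--save_steps", type=int, default=10000)
    parser.add_argument("--milestones", type=int, nargs=2, default=[50000, 150000])
    parser.add_argument("--exclusive_rate", type=float, default=1.0)
    parser.add_argument("--n_samples", type=int, default=10)
    parser.add_argument("--accu_steps", type=int, default=2)
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--n_workers", type=int, default=8)
    parser.add_argument("--preload", action="store_true")
    parser.add_argument("--comment", type=str, default=None)
    parser.add_argument("--ckpt", type=str, default=None)
    parser.add_argument("--grad_norm_clip", type=float, default=10.0)
    parser.add_argument("--use_target_features", action="store_true")
    return vars(parser.parse_args())


if __name__ == "__main__":
    # `main` accepts **kwargs, so unpacking the parsed flags is safe.
    main(**parse_args())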
import os
import time

import torch
from torch.distributed import init_process_group
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.tensorboard import SummaryWriter
from omegaconf import DictConfig, OmegaConf

# Project-local modules (defined elsewhere in the repository): Generator,
# Discriminator, FeatureDataset, RAdam, scan_checkpoint, load_checkpoint,
# save_checkpoint, load_dataset_filelist, discriminator_loss,
# adversarial_loss, and the module-level `criterion` (a spectral/STFT loss).


def train(rank: int, cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

    if cfg.train.n_gpu > 1:
        init_process_group(
            backend=cfg.train.dist_config['dist_backend'],
            init_method=cfg.train.dist_config['dist_url'],
            world_size=cfg.train.dist_config['world_size'] * cfg.train.n_gpu,
            rank=rank,
        )

    device = torch.device(
        'cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')

    generator = Generator(
        sum(cfg.model.feature_dims), *cfg.model.cond_dims, **cfg.model.generator
    ).to(device)
    discriminator = Discriminator(**cfg.model.discriminator).to(device)

    if rank == 0:
        print(generator)
        os.makedirs(cfg.train.ckpt_dir, exist_ok=True)
        print("checkpoints directory : ", cfg.train.ckpt_dir)

    # Initialize to None so the checks below are safe even when the
    # checkpoint directory does not exist yet.
    cp_g = cp_do = None
    if os.path.isdir(cfg.train.ckpt_dir):
        cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')
        cp_do = scan_checkpoint(cfg.train.ckpt_dir, 'd_')

    steps = 1
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        discriminator.load_state_dict(state_dict_do['discriminator'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if cfg.train.n_gpu > 1:
        generator = DistributedDataParallel(generator, device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator, device_ids=[rank]).to(device)

    optim_g = RAdam(generator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
    optim_d = RAdam(discriminator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)

    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)

    train_filelist = load_dataset_filelist(cfg.dataset.train_list)

    trainset = FeatureDataset(cfg.dataset, train_filelist, cfg.data)
    train_sampler = DistributedSampler(trainset) if cfg.train.n_gpu > 1 else None
    train_loader = DataLoader(
        trainset,
        batch_size=cfg.train.batch_size,
        num_workers=cfg.train.num_workers,
        # `shuffle` and a sampler are mutually exclusive in DataLoader;
        # shuffle only when no DistributedSampler is in use.
        shuffle=(train_sampler is None),
        sampler=train_sampler,
        pin_memory=True,
        drop_last=True,
    )

    if rank == 0:
        val_filelist = load_dataset_filelist(cfg.dataset.test_list)
        valset = FeatureDataset(cfg.dataset, val_filelist, cfg.data, segmented=False)
        # Validation runs on rank 0 only, so no distributed sampler here.
        val_loader = DataLoader(
            valset,
            batch_size=1,
            num_workers=cfg.train.num_workers,
            shuffle=False,
            sampler=None,
            pin_memory=True,
        )
        sw = SummaryWriter(os.path.join(cfg.train.ckpt_dir, 'logs'))

    generator.train()
    discriminator.train()
    for epoch in range(max(0, last_epoch), cfg.train.epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if cfg.train.n_gpu > 1:
            train_sampler.set_epoch(epoch)

        for y, x_noised_features, x_noised_cond in train_loader:
            if rank == 0:
                start_b = time.time()
            y = y.to(device, non_blocking=True)
            x_noised_features = x_noised_features.transpose(1, 2).to(
                device, non_blocking=True)
            x_noised_cond = x_noised_cond.to(device, non_blocking=True)

            # Two generator passes with independent latent codes.
            z1 = torch.randn(cfg.train.batch_size, cfg.model.cond_dims[1], device=device)
            z2 = torch.randn(cfg.train.batch_size, cfg.model.cond_dims[1], device=device)
            y_hat1 = generator(x_noised_features, x_noised_cond, z=z1)
            y_hat2 = generator(x_noised_features, x_noised_cond, z=z2)

            # Discriminator: score real audio and detached generator output.
            real_scores = discriminator(y)
            fake_scores = discriminator(y_hat1.detach())
            d_loss = discriminator_loss(real_scores, fake_scores)

            optim_d.zero_grad()
            d_loss.backward()
            optim_d.step()

            # Generator: re-score the non-detached output so the adversarial
            # loss actually propagates gradients into the generator.
            fake_scores = discriminator(y_hat1)
            g_stft_loss = criterion(y, y_hat1) + criterion(y, y_hat2) \
                - criterion(y_hat1, y_hat2)
            g_adv_loss = adversarial_loss(fake_scores)
            g_loss = g_adv_loss + g_stft_loss

            optim_g.zero_grad()
            g_loss.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % cfg.train.stdout_interval == 0:
                    print(
                        'Steps : {:d}, Gen Loss Total : {:4.3f}, STFT Error : {:4.3f}, s/b : {:4.3f}'
                        .format(steps, g_loss.item(), g_stft_loss.item(),
                                time.time() - start_b))

                # Checkpointing
                if steps % cfg.train.checkpoint_interval == 0:
                    ckpt_path = "{}/g_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_path, {
                            'generator': (generator.module if cfg.train.n_gpu > 1
                                          else generator).state_dict()
                        })
                    ckpt_path = "{}/do_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_path, {
                            'discriminator': (discriminator.module
                                              if cfg.train.n_gpu > 1
                                              else discriminator).state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch
                        })

                # Tensorboard summary logging
                if steps % cfg.train.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", g_loss, steps)
                    sw.add_scalar("training/gen_stft_error", g_stft_loss, steps)

                # Validation
                if steps % cfg.train.validation_interval == 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, (y, x_noised_features,
                                x_noised_cond) in enumerate(val_loader):
                            # Move the target to the device before the loss.
                            y = y.to(device)
                            y_hat = generator(
                                x_noised_features.transpose(1, 2).to(device),
                                x_noised_cond.to(device))
                            val_err_tot += criterion(y, y_hat).item()

                            if j <= 4:
                                # sw.add_audio('noised/y_noised_{}'.format(j), y_noised[0], steps, cfg.data.target_sample_rate)
                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_hat[0], steps, cfg.data.sample_rate)
                                sw.add_audio('gt/y_{}'.format(j), y[0], steps,
                                             cfg.data.sample_rate)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/stft_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))
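# --- Hypothetical launcher (not part of the original source): a minimal sketch
# of how `train(rank, cfg)` above could be started, assuming a Hydra config
# file "config.yaml" in the current directory that defines the cfg.train,
# cfg.model, cfg.opt, cfg.data, and cfg.dataset groups used above.
import hydra
import torch.multiprocessing as mp


@hydra.main(config_path=".", config_name="config")
def launch(cfg: DictConfig):
    if cfg.train.n_gpu > 1:
        # mp.spawn passes the process rank as the first positional argument,
        # matching the `train(rank, cfg)` signature.
        mp.spawn(train, nprocs=cfg.train.n_gpu, args=(cfg,))
    else:
        train(0, cfg)


if __name__ == "__main__":
    launch()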