Example #1
def main(
    data_dir,
    save_dir,
    total_steps,
    warmup_steps,
    valid_steps,
    log_steps,
    save_steps,
    milestones,
    exclusive_rate,
    n_samples,
    accu_steps,
    batch_size,
    n_workers,
    preload,
    comment,
    ckpt,
    grad_norm_clip,
    use_target_features,
    **kwargs,
):
    """Main function."""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    metadata_path = Path(data_dir) / "metadata.json"

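    # Build the intra-speaker dataset and split it 90/10 into train and validation subsets.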
    dataset = IntraSpeakerDataset(data_dir,
                                  metadata_path,
                                  n_samples,
                                  preload,
                                  ref_feat=use_target_features)
    trainlen = int(0.9 * len(dataset))
    lengths = [trainlen, len(dataset) - trainlen]
    trainset, validset = random_split(dataset, lengths)
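    # Wrap the splits in DataLoaders; the validation batch matches the effective training batch (batch_size * accu_steps).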
    train_loader = DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
        num_workers=n_workers,
        pin_memory=True,
        collate_fn=collate_batch,
    )
    valid_loader = DataLoader(
        validset,
        batch_size=batch_size * accu_steps,
        num_workers=n_workers,
        drop_last=True,
        pin_memory=True,
        collate_fn=collate_batch,
    )
    train_iterator = iter(train_loader)

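    # Log to TensorBoard only when a run comment is provided.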
    if comment is not None:
        log_dir = "logs/"
        log_dir += datetime.datetime.now().strftime("%Y-%m-%d_%H:%M:%S")
        log_dir += "_" + comment
        writer = SummaryWriter(log_dir)

    save_dir_path = Path(save_dir)
    save_dir_path.mkdir(parents=True, exist_ok=True)

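    # Resume from a TorchScript checkpoint if one is given; otherwise build a fresh model.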
    if ckpt is not None:
        try:
            # Recover the step count from a name like "retriever-step10000-loss...".
            start_step = int(ckpt.split('-')[1][4:])
            ref_included = True
        except (IndexError, ValueError):
            start_step = 0
            ref_included = False

        model = torch.jit.load(ckpt).to(device)
        optimizer = RAdam(
            [
                {"params": model.unet.parameters(), "lr": 1e-6},
                {"params": model.smoothers.parameters()},
                {"params": model.mel_linear.parameters()},
                {"params": model.post_net.parameters()},
            ],
            lr=1e-4,
        )
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps,
                                                    total_steps - start_step)
        print("Optimizer and scheduler restarted.")
        print(f"Model loaded from {ckpt}, iteration: {start_step}")
    else:
        ref_included = False
        start_step = 0

        model = FragmentVC().to(device)
        model = torch.jit.script(model)
        optimizer = RAdam(model.parameters(), lr=1e-4)
        scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps,
                                                    total_steps)

    criterion = nn.L1Loss()

    best_loss = float("inf")
    best_state_dict = None

    self_exclude = 0.0

    pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

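    # Training loop: accumulate gradients over accu_steps micro-batches per optimizer step.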
    for step in range(start_step, total_steps):
        batch_loss = 0.0

        for _ in range(accu_steps):
            try:
                batch = next(train_iterator)
            except StopIteration:
                train_iterator = iter(train_loader)
                batch = next(train_iterator)

            loss = model_fn(batch, model, criterion, self_exclude,
                            ref_included, device)
            loss = loss / accu_steps
            batch_loss += loss.item()
            loss.backward()

        # Clip gradients before the optimizer step so clipping actually takes effect.
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_norm_clip)
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        pbar.update()
        pbar.set_postfix(loss=f"{batch_loss:.2f}",
                         excl=self_exclude,
                         step=step + 1)

        if step % log_steps == 0 and comment is not None:
            writer.add_scalar("Loss/train", batch_loss, step)
            writer.add_scalar("Self-exclusive Rate", self_exclude, step)

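        # Periodically run validation and keep the best-performing weights in memory.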
        if (step + 1) % valid_steps == 0:
            pbar.close()

            valid_loss = valid(valid_loader, model, criterion, device)

            if comment is not None:
                writer.add_scalar("Loss/valid", valid_loss, step + 1)

            if valid_loss < best_loss:
                best_loss = valid_loss
                best_state_dict = model.state_dict()

            pbar = tqdm(total=valid_steps, ncols=0, desc="Train", unit=" step")

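        # Save both the best and the current model as TorchScript checkpoints.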
        if (step + 1) % save_steps == 0 and best_state_dict is not None:
            loss_str = f"{best_loss:.4f}".replace(".", "dot")
            best_ckpt_name = f"retriever-best-loss{loss_str}.pt"

            loss_str = f"{valid_loss:.4f}".replace(".", "dot")
            curr_ckpt_name = f"retriever-step{step+1}-loss{loss_str}.pt"

            current_state_dict = model.state_dict()
            model.cpu()

            model.load_state_dict(best_state_dict)
            model.save(str(save_dir_path / best_ckpt_name))

            model.load_state_dict(current_state_dict)
            model.save(str(save_dir_path / curr_ckpt_name))

            model.to(device)
            pbar.write(
                f"Step {step + 1}, best model saved. (loss={best_loss:.4f})")

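        # Reference features are enabled at milestones[0]; the self-exclusive rate then ramps linearly up to exclusive_rate by milestones[1].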
        if (step + 1) >= milestones[1]:
            self_exclude = exclusive_rate

        elif (step + 1) == milestones[0]:
            ref_included = True
            optimizer = RAdam(
                [
                    {"params": model.unet.parameters(), "lr": 1e-6},
                    {"params": model.smoothers.parameters()},
                    {"params": model.mel_linear.parameters()},
                    {"params": model.post_net.parameters()},
                ],
                lr=1e-4,
            )
            scheduler = get_cosine_schedule_with_warmup(
                optimizer, warmup_steps, total_steps - milestones[0])
            pbar.write("Optimizer and scheduler restarted.")

        elif (step + 1) > milestones[0]:
            self_exclude = (step + 1 - milestones[0]) / (milestones[1] -
                                                         milestones[0])
            self_exclude *= exclusive_rate

    pbar.close()
Example #2
def train(rank: int, cfg: DictConfig):
    print(OmegaConf.to_yaml(cfg))

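    # Initialize the distributed process group when training on more than one GPU.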
    if cfg.train.n_gpu > 1:
        init_process_group(backend=cfg.train.dist_config['dist_backend'],
                           init_method=cfg.train.dist_config['dist_url'],
                           world_size=cfg.train.dist_config['world_size'] *
                           cfg.train.n_gpu,
                           rank=rank)

    device = torch.device(
        'cuda:{:d}'.format(rank) if torch.cuda.is_available() else 'cpu')

    generator = Generator(sum(cfg.model.feature_dims), *cfg.model.cond_dims,
                          **cfg.model.generator).to(device)
    discriminator = Discriminator(**cfg.model.discriminator).to(device)

    if rank == 0:
        print(generator)
        os.makedirs(cfg.train.ckpt_dir, exist_ok=True)
        print("checkpoints directory : ", cfg.train.ckpt_dir)

    # Resume from the latest checkpoints if any are found in ckpt_dir.
    cp_g = cp_do = None
    if os.path.isdir(cfg.train.ckpt_dir):
        cp_g = scan_checkpoint(cfg.train.ckpt_dir, 'g_')
        cp_do = scan_checkpoint(cfg.train.ckpt_dir, 'd_')

    steps = 1
    if cp_g is None or cp_do is None:
        state_dict_do = None
        last_epoch = -1
    else:
        state_dict_g = load_checkpoint(cp_g, device)
        state_dict_do = load_checkpoint(cp_do, device)
        generator.load_state_dict(state_dict_g['generator'])
        discriminator.load_state_dict(state_dict_do['discriminator'])
        steps = state_dict_do['steps'] + 1
        last_epoch = state_dict_do['epoch']

    if cfg.train.n_gpu > 1:
        generator = DistributedDataParallel(generator,
                                            device_ids=[rank]).to(device)
        discriminator = DistributedDataParallel(discriminator,
                                                device_ids=[rank]).to(device)

    optim_g = RAdam(generator.parameters(), cfg.opt.lr, betas=cfg.opt.betas)
    optim_d = RAdam(discriminator.parameters(),
                    cfg.opt.lr,
                    betas=cfg.opt.betas)

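    # Restore optimizer states when resuming from a checkpoint.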
    if state_dict_do is not None:
        optim_g.load_state_dict(state_dict_do['optim_g'])
        optim_d.load_state_dict(state_dict_do['optim_d'])

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
        optim_d, gamma=cfg.opt.lr_decay, last_epoch=last_epoch)

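    # Build the training data pipeline; with multiple GPUs a DistributedSampler shards the data across ranks.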
    train_filelist = load_dataset_filelist(cfg.dataset.train_list)
    trainset = FeatureDataset(cfg.dataset, train_filelist, cfg.data)
    train_sampler = DistributedSampler(
        trainset) if cfg.train.n_gpu > 1 else None
    train_loader = DataLoader(trainset,
                              batch_size=cfg.train.batch_size,
                              num_workers=cfg.train.num_workers,
                              shuffle=True,
                              sampler=train_sampler,
                              pin_memory=True,
                              drop_last=True)

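    # Only rank 0 builds the validation loader and the TensorBoard writer.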
    if rank == 0:
        val_filelist = load_dataset_filelist(cfg.dataset.test_list)
        valset = FeatureDataset(cfg.dataset,
                                val_filelist,
                                cfg.data,
                                segmented=False)
        val_loader = DataLoader(valset,
                                batch_size=1,
                                num_workers=cfg.train.num_workers,
                                shuffle=False,
                                sampler=None,
                                pin_memory=True)

        sw = SummaryWriter(os.path.join(cfg.train.ckpt_dir, 'logs'))

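    # Main loop: one discriminator update followed by one generator update per batch.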
    generator.train()
    discriminator.train()
    for epoch in range(max(0, last_epoch), cfg.train.epochs):
        if rank == 0:
            start = time.time()
            print("Epoch: {}".format(epoch + 1))

        if cfg.train.n_gpu > 1:
            train_sampler.set_epoch(epoch)

        for y, x_noised_features, x_noised_cond in train_loader:
            if rank == 0:
                start_b = time.time()

            y = y.to(device, non_blocking=True)
            x_noised_features = x_noised_features.transpose(1, 2).to(
                device, non_blocking=True)
            x_noised_cond = x_noised_cond.to(device, non_blocking=True)
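            # Two independent latent codes yield two stochastic outputs for the same
            # input; the STFT loss below pushes them apart while keeping both close
            # to the target.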
            z1 = torch.randn(cfg.train.batch_size,
                             cfg.model.cond_dims[1],
                             device=device)
            z2 = torch.randn(cfg.train.batch_size,
                             cfg.model.cond_dims[1],
                             device=device)

            y_hat1 = generator(x_noised_features, x_noised_cond, z=z1)
            y_hat2 = generator(x_noised_features, x_noised_cond, z=z2)

            # Discriminator update on detached generator outputs.
            real_scores = discriminator(y)
            fake_scores = discriminator(y_hat1.detach())
            d_loss = discriminator_loss(real_scores, fake_scores)

            optim_d.zero_grad()
            d_loss.backward()
            optim_d.step()

            # Generator update: re-score the non-detached output so the adversarial
            # gradient reaches the generator. `criterion` is the reconstruction
            # (STFT) loss, assumed to be defined at module scope.
            fake_scores = discriminator(y_hat1)
            g_stft_loss = criterion(y, y_hat1) + criterion(
                y, y_hat2) - criterion(y_hat1, y_hat2)
            g_adv_loss = adversarial_loss(fake_scores)
            g_loss = g_adv_loss + g_stft_loss

            optim_g.zero_grad()
            g_loss.backward()
            optim_g.step()

            if rank == 0:
                # STDOUT logging
                if steps % cfg.train.stdout_interval == 0:
                    with torch.no_grad():
                        print(
                            'Steps : {:d}, Gen Loss Total : {:4.3f}, STFT Error : {:4.3f}, s/b : {:4.3f}'
                            .format(steps, g_loss, g_stft_loss,
                                    time.time() - start_b))

                # checkpointing
                if steps % cfg.train.checkpoint_interval == 0:
                    g_model = (generator.module
                               if cfg.train.n_gpu > 1 else generator)
                    d_model = (discriminator.module
                               if cfg.train.n_gpu > 1 else discriminator)

                    ckpt_path = "{}/g_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(ckpt_path,
                                    {'generator': g_model.state_dict()})

                    ckpt_path = "{}/do_{:08d}".format(cfg.train.ckpt_dir, steps)
                    save_checkpoint(
                        ckpt_path, {
                            'discriminator': d_model.state_dict(),
                            'optim_g': optim_g.state_dict(),
                            'optim_d': optim_d.state_dict(),
                            'steps': steps,
                            'epoch': epoch
                        })

                # Tensorboard summary logging
                if steps % cfg.train.summary_interval == 0:
                    sw.add_scalar("training/gen_loss_total", g_loss, steps)
                    sw.add_scalar("training/gen_stft_error", g_stft_loss,
                                  steps)

                # Validation
                if steps % cfg.train.validation_interval == 0:
                    generator.eval()
                    torch.cuda.empty_cache()
                    val_err_tot = 0
                    with torch.no_grad():
                        for j, (y, x_noised_features,
                                x_noised_cond) in enumerate(val_loader):
                            # Move the target to the same device as the generator output.
                            y = y.to(device)
                            y_hat = generator(
                                x_noised_features.transpose(1, 2).to(device),
                                x_noised_cond.to(device))
                            val_err_tot += criterion(y, y_hat).item()

                            if j <= 4:
                                # sw.add_audio('noised/y_noised_{}'.format(j), y_noised[0], steps, cfg.data.target_sample_rate)
                                sw.add_audio('generated/y_hat_{}'.format(j),
                                             y_hat[0], steps,
                                             cfg.data.sample_rate)
                                sw.add_audio('gt/y_{}'.format(j), y[0], steps,
                                             cfg.data.sample_rate)

                        val_err = val_err_tot / (j + 1)
                        sw.add_scalar("validation/stft_error", val_err, steps)

                    generator.train()

            steps += 1

        scheduler_g.step()
        scheduler_d.step()

        if rank == 0:
            print('Time taken for epoch {} is {} sec\n'.format(
                epoch + 1, int(time.time() - start)))