Code example #1
    def validation_epoch_end(self, outputs):
        batch = outputs[0]["batch"]
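        # Free cached GPU memory before the memory-hungry FID and PPL passes.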
        gc.collect()
        th.cuda.empty_cache()
        val_fid = validation.fid(
            self.g_ema.to(batch.device),
            self.val_batch_size,
            self.fid_n_sample,
            self.fid_truncation,
            self.name,
        )["FID"]
        val_ppl = validation.ppl(
            self.g_ema.to(batch.device),
            self.val_batch_size,
            self.ppl_n_sample,
            self.ppl_space,
            self.ppl_crop,
            self.latent_size,
        )
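        # Log fixed-z sample grids from both the EMA generator and the live generator.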
        with th.no_grad():
            self.g_ema.eval()
            sample, _ = self.g_ema(
                [self.sample_z.to(next(self.g_ema.parameters()).device)])
            grid = tv.utils.make_grid(sample,
                                      nrow=int(
                                          round(4.0 / 3 * self.n_sample**0.5)),
                                      normalize=True,
                                      range=(-1, 1))  # note: `range=` was renamed `value_range=` in newer torchvision
            self.logger.experiment.log({
                "Generated Images EMA":
                [wandb.Image(grid, caption=f"Step {self.global_step}")]
            })

            self.generator.eval()
            sample, _ = self.generator(
                [self.sample_z.to(next(self.generator.parameters()).device)])
            grid = tv.utils.make_grid(sample,
                                      nrow=int(
                                          round(4.0 / 3 * self.n_sample**0.5)),
                                      normalize=True,
                                      range=(-1, 1))
            self.logger.experiment.log({
                "Generated Images":
                [wandb.Image(grid, caption=f"Step {self.global_step}")]
            })
            self.generator.train()

        gc.collect()
        th.cuda.empty_cache()

        return {
            "val_loss": val_fid,
            "log": {
                "Validation/FID": val_fid,
                "Validation/PPL": val_ppl
            }
        }
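
The hook above reads outputs[0]["batch"] only to recover a device reference, so the matching validation_step (not shown on this page) just needs to return the batch. A minimal sketch, assuming the pre-2.0 PyTorch Lightning API in which validation_epoch_end receives the list of validation_step outputs:

    def validation_step(self, batch, batch_idx):
        # No metrics are computed here; validation_epoch_end only needs a
        # handle on the batch to know which device to move g_ema to.
        return {"batch": batch}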
Code example #2
File: train.py Project: johndpope/maua-stylegan2
def train(args, loader, generator, discriminator, contrast_learner, augment, g_optim, d_optim, scaler, g_ema, device):
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar, initial=args.start_iter, dynamic_ncols=True, smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = th.zeros(size=(1,), device=device)
    g_loss_val = 0
    path_loss = th.zeros(size=(1,), device=device)
    path_lengths = th.zeros(size=(1,), device=device)
    loss_dict = {}
    mse = th.nn.MSELoss()

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    sample_z = th.randn(args.n_sample, args.latent_size, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")
            break

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        discriminator.zero_grad()

        loss_dict["d"], loss_dict["real_score"], loss_dict["fake_score"] = 0, 0, 0
        loss_dict["cl_reg"], loss_dict["bc_reg"] = (
            th.tensor(0, device=device).float(),
            th.tensor(0, device=device).float(),
        )
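        # Gradient accumulation: each micro-batch loss below is divided by
        # num_accumulate, so the summed gradients match one effective batch of
        # size batch_size * num_accumulate.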
        for _ in range(args.num_accumulate):
            # sample = []
            # for _ in range(0, len(sample_z), args.batch_size):
            #     subsample = next(loader)
            #     sample.append(subsample)
            # sample = th.cat(sample)
            # utils.save_image(sample, "reals-no-augment.png", nrow=10, normalize=True)
            # utils.save_image(augment(sample), "reals-augment.png", nrow=10, normalize=True)

            real_img = next(loader)
            real_img = real_img.to(device)

            # with th.cuda.amp.autocast():
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_D:
                fake_pred = discriminator(augment(fake_img))
                real_pred = discriminator(augment(real_img))
            else:
                fake_pred = discriminator(fake_img)
                real_pred = discriminator(real_img)

            # logistic loss
            real_loss = F.softplus(-real_pred)
            fake_loss = F.softplus(fake_pred)
            d_loss = real_loss.mean() + fake_loss.mean()

            loss_dict["d"] += d_loss.detach()
            loss_dict["real_score"] += real_pred.mean().detach()
            loss_dict["fake_score"] += fake_pred.mean().detach()

            if i > 10000 or i == 0:  # extra D regularizers only after iteration 10000 (and once at i == 0)
                if args.contrastive > 0:
                    contrast_learner(fake_img.clone().detach(), accumulate=True)
                    contrast_learner(real_img, accumulate=True)

                    contrast_loss = cl_module.calculate_loss()
                    loss_dict["cl_reg"] += contrast_loss.detach()

                    d_loss += args.contrastive * contrast_loss

                if args.balanced_consistency > 0:
                    aug_fake_pred = discriminator(augment(fake_img.clone().detach()))
                    aug_real_pred = discriminator(augment(real_img))

                    consistency_loss = mse(real_pred, aug_real_pred) + mse(fake_pred, aug_fake_pred)
                    loss_dict["bc_reg"] += consistency_loss.detach()

                    d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            # scaler.scale(d_loss).backward()
            d_loss.backward()

        # scaler.step(d_optim)
        d_optim.step()

        # R1 regularization
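        # R1 penalizes the discriminator's gradient on real images,
        # E[||grad_x D(x)||^2]. It is applied lazily: running it only every
        # d_reg_every steps and scaling the loss by d_reg_every keeps the
        # effective regularization strength roughly unchanged.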
        if args.r1 > 0 and i % args.d_reg_every == 0:

            discriminator.zero_grad()

            loss_dict["r1"] = 0
            for _ in range(args.num_accumulate):
                real_img = next(loader)
                real_img = real_img.to(device)

                real_img.requires_grad = True

                # with th.cuda.amp.autocast():
                # if args.augment_D:
                #     real_pred = discriminator(
                #         augment(real_img)
                #     )  # RuntimeError: derivative for grid_sampler_2d_backward is not implemented :(
                # else:
                real_pred = discriminator(real_img)
                real_pred_sum = real_pred.sum()

                (grad_real,) = th.autograd.grad(outputs=real_pred_sum, inputs=real_img, create_graph=True)
                # (grad_real,) = th.autograd.grad(outputs=scaler.scale(real_pred_sum), inputs=real_img, create_graph=True)
                # grad_real = grad_real * (1.0 / scaler.get_scale())

                # with th.cuda.amp.autocast():
                r1_loss = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
                weighted_r1_loss = args.r1 / 2.0 * r1_loss * args.d_reg_every + 0 * real_pred[0]  # 0 * real_pred[0] is a no-op that keeps real_pred in the autograd graph

                loss_dict["r1"] += r1_loss.detach()

                weighted_r1_loss /= args.num_accumulate
                # scaler.scale(weighted_r1_loss).backward()
                weighted_r1_loss.backward()

            # scaler.step(d_optim)
            d_optim.step()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        generator.zero_grad()
        loss_dict["g"] = 0
        for _ in range(args.num_accumulate):
            # with th.cuda.amp.autocast():
            noise = make_noise(args.batch_size, args.latent_size, args.mixing_prob, device)
            fake_img, _ = generator(noise)

            if args.augment_G:
                fake_img = augment(fake_img)

            fake_pred = discriminator(fake_img)

            # non-saturating loss
            g_loss = F.softplus(-fake_pred).mean()

            loss_dict["g"] += g_loss.detach()

            g_loss /= args.num_accumulate
            # scaler.scale(g_loss).backward()
            g_loss.backward()

        # scaler.step(g_optim)
        g_optim.step()

        # path length regularization
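        # StyleGAN2's path length regularizer pulls the per-sample Jacobian
        # norms ||J_w^T y|| toward their running mean (updated with decay 0.01),
        # encouraging a well-conditioned mapping from W to image space. Like R1
        # it is lazy, hence the g_reg_every factor below.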
        if args.path_regularize > 0 and i % args.g_reg_every == 0:

            generator.zero_grad()

            loss_dict["path"], loss_dict["path_length"] = 0, 0
            for _ in range(args.num_accumulate):
                path_batch_size = max(1, args.batch_size // args.path_batch_shrink)

                # with th.cuda.amp.autocast():
                noise = make_noise(path_batch_size, args.latent_size, args.mixing_prob, device)
                fake_img, latents = generator(noise, return_latents=True)

                img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] * fake_img.shape[3])
                noisy_img_sum = (fake_img * img_noise).sum()

                (grad,) = th.autograd.grad(outputs=noisy_img_sum, inputs=latents, create_graph=True)
                # (grad,) = th.autograd.grad(outputs=scaler.scale(noisy_img_sum), inputs=latents, create_graph=True)
                # grad = grad * (1.0 / scaler.get_scale())

                # with th.cuda.amp.autocast():
                path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1))
                path_mean = mean_path_length + 0.01 * (path_lengths.mean() - mean_path_length)
                path_loss = (path_lengths - path_mean).pow(2).mean()
                mean_path_length = path_mean.detach()

                loss_dict["path"] += path_loss.detach()
                loss_dict["path_length"] += path_lengths.mean().detach()

                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                if args.path_batch_shrink:
                    weighted_path_loss += 0 * fake_img[0, 0, 0, 0]  # no-op term keeping fake_img in the autograd graph

                weighted_path_loss /= args.num_accumulate
                # scaler.scale(weighted_path_loss).backward()
                weighted_path_loss.backward()

            # scaler.step(g_optim)
            g_optim.step()

        # scaler.update()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item() / args.num_accumulate
        g_loss_val = loss_reduced["g"].mean().item() / args.num_accumulate
        cl_reg_val = loss_reduced["cl_reg"].mean().item() / args.num_accumulate
        bc_reg_val = loss_reduced["bc_reg"].mean().item() / args.num_accumulate
        r1_val = loss_reduced["r1"].mean().item() / args.num_accumulate
        path_loss_val = loss_reduced["path"].mean().item() / args.num_accumulate
        real_score_val = loss_reduced["real_score"].mean().item() / args.num_accumulate
        fake_score_val = loss_reduced["fake_score"].mean().item() / args.num_accumulate
        path_length_val = loss_reduced["path_length"].mean().item() / args.num_accumulate

        if get_rank() == 0:

            log_dict = {
                "Generator": g_loss_val,
                "Discriminator": d_loss_val,
                "Real Score": real_score_val,
                "Fake Score": fake_score_val,
                "Contrastive": cl_reg_val,
                "Consistency": bc_reg_val,
            }

            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)
                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)

                log_dict[f"Spectral Norms/G min spectral norm"] = np.log(G_norms).min()
                log_dict[f"Spectral Norms/G mean spectral norm"] = np.log(G_norms).mean()
                log_dict[f"Spectral Norms/G max spectral norm"] = np.log(G_norms).max()
                log_dict[f"Spectral Norms/D min spectral norm"] = np.log(D_norms).min()
                log_dict[f"Spectral Norms/D mean spectral norm"] = np.log(D_norms).mean()
                log_dict[f"Spectral Norms/D max spectral norm"] = np.log(D_norms).max()

            if args.r1 > 0 and i % args.d_reg_every == 0:
                log_dict["R1"] = r1_val

            if args.path_regularize > 0 and i % args.g_reg_every == 0:
                log_dict["Path Length Regularization"] = path_loss_val
                log_dict["Mean Path Length"] = mean_path_length
                log_dict["Path Length"] = path_length_val

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema([sample_z[sub : sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample, nrow=10, normalize=True, range=(-1, 1))
                    # utils.save_image(sample, "fakes-no-augment.png", nrow=10, normalize=True)
                    # utils.save_image(augment(sample), "fakes-augment.png", nrow=10, normalize=True)
                    # exit()
                log_dict["Generated Images EMA"] = [wandb.Image(grid, caption=f"Step {i}")]

            if i % args.eval_every == 0:
                start_time = time.time()
                pbar.set_description((f"Calculating FID..."))
                fid_dict = validation.fid(g_ema, args.val_batch_size, args.fid_n_sample, args.fid_truncation, args.name)
                fid = fid_dict["FID"]
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                pbar.set_description((f"Calculating PPL..."))
                ppl = validation.ppl(
                    g_ema, args.val_batch_size, args.ppl_n_sample, args.ppl_space, args.ppl_crop, args.latent_size,
                )

                pbar.set_description(
                    f"FID: {fid:.4f}; Density: {density:.4f}; Coverage: {coverage:.4f}; PPL: {ppl:.4f} in {time.time() - start_time:.1f}s"
                )
                log_dict["Evaluation/FID"] = fid
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

                gc.collect()
                th.cuda.empty_cache()

            wandb.log(log_dict)

            if i % args.checkpoint_every == 0:
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{args.name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}-{int(fid)}-{int(ppl)}-{str(i).zfill(6)}.pt",
                )
Code example #3
File: train.py Project: monolesan/maua-stylegan2
def train(args, loader, generator, discriminator, contrast_learner, g_optim,
          d_optim, g_ema):
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module
        if contrast_learner is not None:
            cl_module = contrast_learner.module
    else:
        g_module = generator
        d_module = discriminator
        cl_module = contrast_learner

    loader = sample_data(loader)
    sample_z = th.randn(args.n_sample, args.latent_size, device=device)
    mse = th.nn.MSELoss()
    mean_path_length = 0
    ada_augment = th.tensor([0.0, 0.0], device=device)
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    ada_aug_step = args.ada_target / args.ada_length
    r_t_stat = 0
    fids = []

    pbar = range(args.iter)
    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    dynamic_ncols=True,
                    smoothing=0)
    for idx in pbar:
        i = idx + args.start_iter
        if i > args.iter:
            print("Done!")
            break

        # One accumulated scalar tensor per logged quantity for this iteration.
        loss_dict = {
            key: th.tensor(0.0, device=device)
            for key in [
                "Generator", "Discriminator", "Real Score", "Fake Score",
                "Contrastive", "Consistency", "R1 Penalty",
                "Path Length Regularization", "Augment", "Rt",
            ]
        }

        requires_grad(generator, False)
        requires_grad(discriminator, True)

        discriminator.zero_grad()
        for _ in range(args.num_accumulate):
            real_img_og = next(loader).to(device)
            noise = make_noise(args.batch_size, args.latent_size,
                               args.mixing_prob)
            fake_img_og, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img_og, ada_aug_p)
                real_img, _ = augment(real_img_og, ada_aug_p)
            else:
                fake_img = fake_img_og
                real_img = real_img_og

            fake_pred = discriminator(fake_img)
            real_pred = discriminator(real_img)
            logistic_loss = d_logistic_loss(real_pred, fake_pred)
            loss_dict["Discriminator"] += logistic_loss.detach()
            loss_dict["Real Score"] += real_pred.mean().detach()
            loss_dict["Fake Score"] += fake_pred.mean().detach()
            d_loss = logistic_loss

            if args.contrastive > 0:
                contrast_learner(fake_img_og, fake_img, accumulate=True)
                contrast_learner(real_img_og, real_img, accumulate=True)
                contrast_loss = cl_module.calculate_loss()
                loss_dict["Contrastive"] += contrast_loss.detach()
                d_loss += args.contrastive * contrast_loss

            if args.balanced_consistency > 0:
                consistency_loss = mse(
                    real_pred, discriminator(real_img_og)) + mse(
                        fake_pred, discriminator(fake_img_og))
                loss_dict["Consistency"] += consistency_loss.detach()
                d_loss += args.balanced_consistency * consistency_loss

            d_loss /= args.num_accumulate
            d_loss.backward()
        d_optim.step()

        if args.r1 > 0 and i % args.d_reg_every == 0:
            discriminator.zero_grad()
            for _ in range(args.num_accumulate):
                real_img = next(loader).to(device)
                real_img.requires_grad = True
                real_pred = discriminator(real_img)
                r1_loss = d_r1_penalty(real_img, real_pred, args)
                loss_dict["R1 Penalty"] += r1_loss.detach().squeeze()
                r1_loss = args.r1 * args.d_reg_every * r1_loss / args.num_accumulate
                r1_loss.backward()
            d_optim.step()

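        # Adaptive discriminator augmentation (ADA): r_t tracks the mean sign
        # of D's predictions on reals. If r_t exceeds ada_target, D is
        # overfitting, so ada_aug_p is nudged up; otherwise it is nudged down,
        # clamped to [0, 1].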
        if args.augment and args.augment_p == 0:
            ada_augment += th.tensor(
                (th.sign(real_pred).sum().item(), real_pred.shape[0]),
                device=device)
            ada_augment = reduce_sum(ada_augment)

            if ada_augment[1] > 255:
                pred_signs, n_pred = ada_augment.tolist()

                r_t_stat = pred_signs / n_pred
                loss_dict["Rt"] = th.tensor(r_t_stat, device=device).float()
                if r_t_stat > args.ada_target:
                    sign = 1
                else:
                    sign = -1

                ada_aug_p += sign * ada_aug_step * n_pred
                ada_aug_p = min(1, max(0, ada_aug_p))
                ada_augment.mul_(0)
                loss_dict["Augment"] = th.tensor(ada_aug_p,
                                                 device=device).float()

        requires_grad(generator, True)
        requires_grad(discriminator, False)

        generator.zero_grad()
        for _ in range(args.num_accumulate):
            noise = make_noise(args.batch_size, args.latent_size,
                               args.mixing_prob)
            fake_img, _ = generator(noise)
            if args.augment:
                fake_img, _ = augment(fake_img, ada_aug_p)
            fake_pred = discriminator(fake_img)
            g_loss = g_non_saturating_loss(fake_pred)
            loss_dict["Generator"] += g_loss.detach()
            g_loss /= args.num_accumulate
            g_loss.backward()
        g_optim.step()

        if args.path_regularize > 0 and i % args.g_reg_every == 0:
            generator.zero_grad()
            for _ in range(args.num_accumulate):
                path_loss, mean_path_length = g_path_length_regularization(
                    generator, mean_path_length, args)
                loss_dict["Path Length Regularization"] += path_loss.detach()
                path_loss = args.path_regularize * args.g_reg_every * path_loss / args.num_accumulate
                path_loss.backward()
            g_optim.step()

        accumulate(g_ema, g_module)

        loss_reduced = reduce_loss_dict(loss_dict)
        log_dict = {
            k: v.mean().item() / args.num_accumulate
            for k, v in loss_reduced.items() if v != 0
        }
        if get_rank() == 0:
            if args.log_spec_norm:
                G_norms = []
                for name, spec_norm in g_module.named_buffers():
                    if "spectral_norm" in name:
                        G_norms.append(spec_norm.cpu().numpy())
                G_norms = np.array(G_norms)
                D_norms = []
                for name, spec_norm in d_module.named_buffers():
                    if "spectral_norm" in name:
                        D_norms.append(spec_norm.cpu().numpy())
                D_norms = np.array(D_norms)
                log_dict[f"Spectral Norms/G min spectral norm"] = np.log(
                    G_norms).min()
                log_dict[f"Spectral Norms/G mean spectral norm"] = np.log(
                    G_norms).mean()
                log_dict[f"Spectral Norms/G max spectral norm"] = np.log(
                    G_norms).max()
                log_dict[f"Spectral Norms/D min spectral norm"] = np.log(
                    D_norms).min()
                log_dict[f"Spectral Norms/D mean spectral norm"] = np.log(
                    D_norms).mean()
                log_dict[f"Spectral Norms/D max spectral norm"] = np.log(
                    D_norms).max()

            if i % args.img_every == 0:
                gc.collect()
                th.cuda.empty_cache()
                with th.no_grad():
                    g_ema.eval()
                    sample = []
                    for sub in range(0, len(sample_z), args.batch_size):
                        subsample, _ = g_ema(
                            [sample_z[sub:sub + args.batch_size]])
                        sample.append(subsample.cpu())
                    sample = th.cat(sample)
                    grid = utils.make_grid(sample,
                                           nrow=10,
                                           normalize=True,
                                           range=(-1, 1))
                log_dict["Generated Images EMA"] = [
                    wandb.Image(grid, caption=f"Step {i}")
                ]

            if i % args.eval_every == 0:
                fid_dict = validation.fid(g_ema, args.val_batch_size,
                                          args.fid_n_sample,
                                          args.fid_truncation, args.name)

                fid = fid_dict["FID"]
                fids.append(fid)
                density = fid_dict["Density"]
                coverage = fid_dict["Coverage"]

                ppl = validation.ppl(
                    g_ema,
                    args.val_batch_size,
                    args.ppl_n_sample,
                    args.ppl_space,
                    args.ppl_crop,
                    args.latent_size,
                )

                log_dict["Evaluation/FID"] = fid
                log_dict["Sweep/FID_smooth"] = gaussian_filter(
                    np.array(fids), [5])[-1]
                log_dict["Evaluation/Density"] = density
                log_dict["Evaluation/Coverage"] = coverage
                log_dict["Evaluation/PPL"] = ppl

                gc.collect()
                th.cuda.empty_cache()

            wandb.log(log_dict)
            description = (
                f"FID: {fid:.4f}   PPL: {ppl:.4f}   Dens: {density:.4f}   Cov: {coverage:.4f}   "
                f"G: {log_dict['Generator']:.4f}   D: {log_dict['Discriminator']:.4f}"
            )
            if "Augment" in log_dict:
                description += f"   Aug: {log_dict['Augment']:.4f}"  #   Rt: {log_dict['Rt']:.4f}"
            if "R1 Penalty" in log_dict:
                description += f"   R1: {log_dict['R1 Penalty']:.4f}"
            if "Path Length Regularization" in log_dict:
                description += f"   Path: {log_dict['Path Length Regularization']:.4f}"
            pbar.set_description(description)

            if i % args.checkpoint_every == 0:
                check_name = "-".join([
                    args.name,
                    args.runname,
                    wandb.run.dir.split("/")[-1].split("-")[-1],
                    int(fid),
                    args.size,
                    str(i).zfill(6),
                ])
                th.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        # "cl": cl_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                    },
                    f"/home/hans/modelzoo/maua-sg2/{check_name}.pt",
                )
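
Code example #3 calls d_r1_penalty and g_path_length_regularization, which are not shown on this page. A plausible sketch of those helpers, reconstructed from the inline logic in code example #2 (the names match the calls above, but the exact signatures and bodies are assumptions):

def d_r1_penalty(real_img, real_pred, args):
    # Squared gradient norm of D's output w.r.t. the real images; the
    # 0 * real_pred[0] term is a no-op that keeps real_pred in the graph.
    (grad_real,) = th.autograd.grad(outputs=real_pred.sum(),
                                    inputs=real_img,
                                    create_graph=True)
    r1 = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
    return r1 / 2.0 + 0 * real_pred[0]


def g_path_length_regularization(generator, mean_path_length, args):
    # Per-sample Jacobian norms in W space, pulled toward a running mean
    # updated with decay 0.01 (see code example #2 for the inline version).
    path_batch_size = max(1, args.batch_size // args.path_batch_shrink)
    noise = make_noise(path_batch_size, args.latent_size, args.mixing_prob)
    fake_img, latents = generator(noise, return_latents=True)
    img_noise = th.randn_like(fake_img) / math.sqrt(fake_img.shape[2] *
                                                    fake_img.shape[3])
    (grad,) = th.autograd.grad(outputs=(fake_img * img_noise).sum(),
                               inputs=latents,
                               create_graph=True)
    path_lengths = th.sqrt(grad.pow(2).sum(2).mean(1))
    path_mean = mean_path_length + 0.01 * (path_lengths.mean() - mean_path_length)
    path_loss = (path_lengths - path_mean).pow(2).mean()
    if args.path_batch_shrink:
        path_loss = path_loss + 0 * fake_img[0, 0, 0, 0]  # keep fake_img in the graph
    return path_loss, path_mean.detach()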