Example #1
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device):
    """Run the adversarial fine-tuning loop for ``args.iter`` iterations.

    Each iteration performs, in order:

    1. a discriminator step on ``d_logistic_loss`` (real vs. generated
       images, both passed through ``augment`` when ``args.augment``);
    2. every ``args.d_reg_every`` iterations, a lazy R1 step via
       ``d_r1_loss``;
    3. a generator step on ``g_nonsaturating_loss``;
    4. every ``args.g_reg_every`` iterations, a lazy path-length step via
       ``g_path_regularize``;
    5. a moving-average update of ``g_ema`` from the unwrapped generator.

    On rank 0 the progress bar is updated every iteration, losses are sent to
    wandb when it is available and ``args.wandb`` is set, a sample grid is
    written every 100 iterations and a checkpoint containing only ``g_ema``
    is written every ``args.save_every`` iterations.

    Args:
        args: Parsed options. Reads iter, start_iter, distributed, augment,
            augment_p, ada_target, ada_length, n_sample, latent, batch,
            mixing, r1, d_reg_every, g_reg_every, path_batch_shrink,
            path_regularize, wandb, style, save_every and model_path.
        loader: Data loader. It is wrapped by ``sample_data`` and consumed
            with ``next``.
        generator: Generator being trained. It exposes ``.module`` when
            ``args.distributed``.
        discriminator: Discriminator being trained. It exposes ``.module``
            when ``args.distributed``.
        g_optim: Optimizer for the generator.
        d_optim: Optimizer for the discriminator.
        g_ema: Generator copy that receives the moving average and is used
            for samples and checkpoints.
        device: Torch device for noise, batches and bookkeeping tensors.
    """
    loader = sample_data(loader)

    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    ncols=140,
                    dynamic_ncols=False,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    # In distributed mode the networks are wrapped; the EMA update below
    # needs the bare generator.
    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Per-step EMA decay: old weights are halved after 10 * 1000 / 32 steps
    # (presumably a 10k-image half-life at batch 32).
    accum = 0.5**(32 / (10 * 1000))
    # A positive args.augment_p fixes the augmentation probability;
    # otherwise it starts at 0 and, with args.augment, is tuned adaptively.
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed latent codes so periodic sample grids are comparable.
    sample_z = torch.randn(args.n_sample, args.latent, device=device)

    for idx in pbar:
        i = idx + args.start_iter

        if i > args.iter:
            print("Done!")

            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # ---- Lazy R1 regularization ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            # Lets d_r1_loss differentiate the prediction w.r.t. the input.
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Scaled by d_reg_every because the penalty is applied only every
            # d_reg_every steps. The zero-weighted real_pred term keeps the
            # output in the graph (presumably a DDP workaround from upstream).
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        noise = mixing_noise(args.batch, args.latent, args.mixing, device)
        fake_img, _ = generator(noise)

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)

        loss_dict["g"] = g_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Lazy path-length regularization ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)
            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            fake_img, latents = generator(noise, return_latents=True)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                # Zero-weighted term ties the image output into the graph.
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        accumulate(g_ema, g_module, accum)

        # ---- Logging / sampling / checkpointing ----
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:05d}; d: {d_loss_val:.4f}; g: {g_loss_val:.4f}; r1: {r1_val:.4f}; "
                f"path: {path_loss_val:.4f}; mean path: {mean_path_length_avg:.4f}; "
                f"augment: {ada_aug_p:.4f}"))

            if wandb and args.wandb:
                wandb.log({
                    "Generator": g_loss_val,
                    "Discriminator": d_loss_val,
                    "Augment": ada_aug_p,
                    "Rt": r_t_stat,
                    "R1": r1_val,
                    "Path Length Regularization": path_loss_val,
                    "Mean Path Length": mean_path_length,
                    "Real Score": real_score_val,
                    "Fake Score": fake_score_val,
                    "Path Length": path_length_val,
                })

            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_z])
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/{args.style}/finetune-{i:06d}.jpg",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if (i + 1) % args.save_every == 0 or (i + 1) == args.iter:
                # Only the EMA generator is saved; the commented entries are
                # what a resumable checkpoint would additionally need.
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"{args.model_path}/{args.style}/fintune-{i + 1:06d}.pt",
                )
Example #2
0
def train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          instyles, Simgs, exstyles, vggloss, id_loss, device):
    """Train the style-conditioned generator against ``discriminator``.

    The generator is called as ``generator(instyle, exstyle, use_res=True,
    z_plus_latent=...)``. Each iteration draws its inputs in one of two
    modes, selected by ``which = i % args.subspace_freq``:

    * ``which == 0``: ``exstyle`` comes from ``get_paired_data``, ``instyle``
      is fresh ``mixing_noise`` and ``z_plus_latent`` is False.
    * otherwise: ``exstyle``, ``instyle`` and the real image all come from
      ``get_paired_data`` (paired data), and ``z_plus_latent`` is True.

    Discriminator side:

    * ``d_logistic_loss``;
    * lazy R1 every ``args.d_reg_every`` iterations.

    Generator side:

    * ``g_nonsaturating_loss``;
    * a contextual loss on VGG feature 2 (weight ``args.CX_loss``);
    * a style loss on globally pooled features 1 and 2
      (``args.style_loss``);
    * an identity loss against the output of the generator run with
      ``use_res=False`` (``args.id_loss``);
    * an L1 perceptual loss weighted by ``vgg_weights``, applied only in
      paired mode (``args.perc_loss``);
    * an L2 penalty on the parameters of ``g_module.res``
      (``args.L2_reg_loss``);
    * lazy path-length regularization every ``args.g_reg_every``
      iterations.

    Only ``g_ema.res`` receives the moving average.

    On rank 0 the progress bar is updated and a sample grid is written every
    100 iterations. A checkpoint containing only ``g_ema`` is written every
    ``args.save_every`` iterations once ``args.save_begin`` is reached, and
    at the final iteration.

    Args:
        args: Parsed options (iteration counts, loss weights, augmentation
            and regularization settings, output paths).
        loader: Data loader. It is wrapped by ``sample_data`` and consumed
            with ``next``.
        generator: Generator being trained. It exposes ``.module`` when
            ``args.distributed`` and has a ``res`` submodule.
        discriminator: Discriminator being trained. It exposes ``.module``
            when ``args.distributed``.
        g_optim: Optimizer for the generator.
        d_optim: Optimizer for the discriminator.
        g_ema: Moving-average generator used for samples and checkpoints.
        instyles, Simgs, exstyles: Data forwarded unchanged to
            ``get_paired_data``.
        vggloss: Feature extractor returning an indexable sequence of feature
            maps (presumably VGG layers — confirm).
        id_loss: Callable scoring ``(generated, reference)`` image pairs.
        device: Torch device for noise, batches and bookkeeping tensors.
    """
    loader = sample_data(loader)
    # Per-feature weights of the L1 perceptual loss (paired mode only).
    vgg_weights = [0.0, 0.5, 1.0, 0.0, 0.0]
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    smoothing=0.01,
                    ncols=180,
                    dynamic_ncols=False)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Per-step EMA decay: old weights are halved after 10 * 1000 / 32 steps.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed inputs so periodic sample grids are comparable.
    sample_instyle = torch.randn(args.n_sample, args.latent, device=device)
    sample_exstyle, _, _ = get_paired_data(instyles,
                                           Simgs,
                                           exstyles,
                                           batch_size=args.n_sample,
                                           random_ind=8)
    sample_exstyle = sample_exstyle.to(device)

    for idx in pbar:
        i = idx + args.start_iter

        which = i % args.subspace_freq  # defines whether we use paired data

        if i > args.iter:
            print("Done!")
            break

        # sample S
        real_img = next(loader)
        real_img = real_img.to(device)

        # ---- Discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        if which == 0:
            # sample z^+_e, z for Lsty, Lcon and Ladv
            exstyle, _, _ = get_paired_data(instyles,
                                            Simgs,
                                            exstyles,
                                            batch_size=args.batch,
                                            random_ind=8)
            exstyle = exstyle.to(device)
            instyle = mixing_noise(args.batch, args.latent, args.mixing,
                                   device)
            z_plus_latent = False
        else:
            # sample z^+_e, z^+_i and S for Eq. (4)
            exstyle, instyle, real_img = get_paired_data(instyles,
                                                         Simgs,
                                                         exstyles,
                                                         batch_size=args.batch,
                                                         random_ind=8)
            exstyle = exstyle.to(device)
            instyle = [instyle.to(device)]
            real_img = real_img.to(device)
            z_plus_latent = True

        fake_img, _ = generator(instyle,
                                exstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred)

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # ---- Lazy R1 regularization ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            # Scaled by d_reg_every because it is applied only every
            # d_reg_every steps. The zero-weighted term keeps real_pred in
            # the graph (presumably a DDP workaround).
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        if which == 0:
            # sample z^+_e, z for Lsty, Lcon and Ladv
            exstyle, _, real_img = get_paired_data(instyles,
                                                   Simgs,
                                                   exstyles,
                                                   batch_size=args.batch,
                                                   random_ind=8)
            real_img = real_img.to(device)
            exstyle = exstyle.to(device)
            instyle = mixing_noise(args.batch, args.latent, args.mixing,
                                   device)
            z_plus_latent = False
        else:
            # sample z^+_e, z^+_i and S for Eq. (4)
            exstyle, instyle, real_img = get_paired_data(instyles,
                                                         Simgs,
                                                         exstyles,
                                                         batch_size=args.batch,
                                                         random_ind=8)
            exstyle = exstyle.to(device)
            instyle = [instyle.to(device)]
            real_img = real_img.to(device)
            z_plus_latent = True

        fake_img, _ = generator(instyle,
                                exstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        # Targets: features/pooled styles of the real image, and the content
        # image the generator produces without its residual branch.
        with torch.no_grad():
            real_img_256 = F.adaptive_avg_pool2d(real_img, 256).detach()
            real_feats = vggloss(real_img_256)
            real_styles = [
                F.adaptive_avg_pool2d(real_feat, output_size=1).detach()
                for real_feat in real_feats
            ]
            real_content, _ = generator(instyle,
                                        None,
                                        use_res=False,
                                        z_plus_latent=z_plus_latent)
            real_content_256 = F.adaptive_avg_pool2d(real_content,
                                                     256).detach()

        fake_img_256 = F.adaptive_avg_pool2d(fake_img, 256)
        fake_feats = vggloss(fake_img_256)
        fake_styles = [
            F.adaptive_avg_pool2d(fake_feat, output_size=1)
            for fake_feat in fake_feats
        ]
        sty_loss = (torch.tensor(0.0, device=device) if args.CX_loss == 0 else
                    FCX.contextual_loss(fake_feats[2],
                                        real_feats[2].detach(),
                                        band_width=0.2,
                                        loss_type='cosine') * args.CX_loss)
        if args.style_loss > 0:
            sty_loss += ((F.mse_loss(fake_styles[1], real_styles[1]) +
                          F.mse_loss(fake_styles[2], real_styles[2])) *
                         args.style_loss)

        ID_loss = (torch.tensor(0.0, device=device) if args.id_loss == 0 else
                   id_loss(fake_img_256, real_content_256) * args.id_loss)

        gr_loss = torch.tensor(0.0, device=device)
        if which > 0:
            for ii, weight in enumerate(vgg_weights):
                if weight * args.perc_loss > 0:
                    gr_loss += F.l1_loss(
                        fake_feats[ii],
                        real_feats[ii].detach()) * weight * args.perc_loss

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred)
        l2_reg_loss = sum(
            torch.norm(p)
            for p in g_module.res.parameters()) * args.L2_reg_loss

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # Lperc
        loss_dict["l2"] = l2_reg_loss  # Lreg in Lcon
        loss_dict["id"] = ID_loss  # LID in Lcon
        loss_dict["sty"] = sty_loss  # Lsty
        # Out-of-place sum so the tensors logged above are not modified.
        g_loss = g_loss + gr_loss + sty_loss + l2_reg_loss + ID_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Lazy path-length regularization ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)

            instyle = mixing_noise(path_batch_size, args.latent, args.mixing,
                                   device)
            exstyle, _, _ = get_paired_data(instyles,
                                            Simgs,
                                            exstyles,
                                            batch_size=path_batch_size,
                                            random_ind=8)
            exstyle = exstyle.to(device)

            fake_img, latents = generator(instyle,
                                          exstyle,
                                          return_latents=True,
                                          use_res=True,
                                          z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Only the residual branch is averaged into g_ema.
        accumulate(g_ema.res, g_module.res, accum)

        # ---- Logging / sampling / checkpointing ----
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        sty_loss_val = loss_reduced["sty"].mean().item()
        l2_loss_val = loss_reduced["l2"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        id_loss_val = loss_reduced["id"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; sty: {sty_loss_val:.3f}; l2: {l2_loss_val:.3f}; id: {id_loss_val:.3f}; "
                f"r1: {r1_val:.3f}; path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.4f};"))

            if i % 100 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([sample_instyle],
                                      sample_exstyle,
                                      use_res=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/{args.style}/dualstylegan-{i:06d}.jpg",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if ((i + 1) >= args.save_begin and
                (i + 1) % args.save_every == 0) or (i + 1) == args.iter:
                # Only the EMA generator is saved.
                torch.save(
                    {
                        #"g": g_module.state_dict(),
                        #"d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        #"g_optim": g_optim.state_dict(),
                        #"d_optim": d_optim.state_dict(),
                        #"args": args,
                        #"ada_aug_p": ada_aug_p,
                    },
                    f"{args.model_path}/{args.style}/{args.model_name}-{i + 1:06d}.pt",
                )
Example #3
0
        )

        discriminator = nn.parallel.DistributedDataParallel(
            discriminator,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False,
        )

    transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True),
    ])

    dataset = MultiResolutionDataset(args.path, transform, args.size)
    loader = data.DataLoader(
        dataset,
        batch_size=args.batch,
        sampler=data_sampler(dataset,
                             shuffle=True,
                             distributed=args.distributed),
        drop_last=True,
    )

    if get_rank() == 0 and wandb is not None and args.wandb:
        wandb.init(project="stylegan 2")

    train(args, loader, generator, discriminator, g_optim, d_optim, g_ema,
          device)
def pretrain(args,
             loader,
             generator,
             discriminator,
             g_optim,
             d_optim,
             g_ema,
             encoder,
             vggloss,
             device,
             inject_index=5,
             savemodel=True):
    """Pretrain the generator's residual branch to reproduce style-mixed images.

    Each iteration samples two codes ``real_zs`` (z1, z2) with
    ``mixing_noise(..., 1.0, ...)``. It then builds, without gradients:

    * ``target_img``: ``generator(real_zs, None, inject_index=inject_index,
      use_res=False)``;
    * ``pspstyle``: the second output of ``encoder`` applied to
      ``generator([z2], None, use_res=False)`` resized to 256.

    The generator is run with ``use_res=True`` on z1 and an extrinsic style.
    That style is ``g_module.get_latent(z2)`` when
    ``i % args.subspace_freq > 0``, and ``pspstyle`` (as a z+ code)
    otherwise. The generator is trained to match ``target_img`` with an L1
    loss on ``vggloss`` features weighted by ``vgg_weights``, plus a
    0.1-weighted non-saturating adversarial loss. The discriminator uses a
    0.1-weighted logistic loss and lazy R1. Lazy path-length regularization
    runs every ``args.g_reg_every`` iterations. Only ``g_ema.res`` receives
    the moving average.

    On rank 0, three fixed preview images are written before training. A
    sample grid is written every 300 iterations, and (if ``savemodel``) a
    full checkpoint every ``args.save_every`` iterations.

    Args:
        args: Parsed options (iteration counts, loss/regularization
            settings, output paths, ``model_name``).
        loader: Data loader. It is wrapped by ``sample_data`` and consumed
            with ``next``.
        generator: Generator being trained. It exposes ``.module`` when
            ``args.distributed`` and has ``res``, ``n_latent`` and
            ``get_latent``.
        discriminator: Discriminator being trained. It exposes ``.module``
            when ``args.distributed``.
        g_optim: Optimizer for the generator.
        d_optim: Optimizer for the discriminator.
        g_ema: Moving-average generator used for samples and checkpoints.
        encoder: Network whose second output (with ``return_latents=True``)
            is used as a z+ style code (presumably a pSp encoder — confirm).
        vggloss: Feature extractor returning an indexable sequence of
            feature maps.
        device: Torch device for noise, batches and bookkeeping tensors.
        inject_index: Passed to the generator when building target images.
        savemodel: When False, no checkpoints are written.
    """
    loader = sample_data(loader)
    # Per-feature weights of the L1 perceptual reconstruction loss.
    vgg_weights = [0.5, 0.5, 0.5, 0.0, 0.0]
    pbar = range(args.iter)

    if get_rank() == 0:
        pbar = tqdm(pbar,
                    initial=args.start_iter,
                    ncols=140,
                    dynamic_ncols=False,
                    smoothing=0.01)

    mean_path_length = 0

    d_loss_val = 0
    r1_loss = torch.tensor(0.0, device=device)
    g_loss_val = 0
    path_loss = torch.tensor(0.0, device=device)
    path_lengths = torch.tensor(0.0, device=device)
    mean_path_length_avg = 0
    loss_dict = {}

    if args.distributed:
        g_module = generator.module
        d_module = discriminator.module

    else:
        g_module = generator
        d_module = discriminator

    # Per-step EMA decay: old weights are halved after 10 * 1000 / 32 steps.
    accum = 0.5**(32 / (10 * 1000))
    ada_aug_p = args.augment_p if args.augment_p > 0 else 0.0
    r_t_stat = 0

    if args.augment and args.augment_p == 0:
        ada_augment = AdaptiveAugment(args.ada_target, args.ada_length, 8,
                                      device)

    # Fixed preview inputs: source (z1), mixed target and style image (z2).
    sample_zs = mixing_noise(args.n_sample, args.latent, 1.0, device)
    with torch.no_grad():
        source_img, _ = generator([sample_zs[0]],
                                  None,
                                  input_is_latent=False,
                                  z_plus_latent=False,
                                  use_res=False)
        source_img = source_img.detach()
        target_img, _ = generator(sample_zs,
                                  None,
                                  input_is_latent=False,
                                  z_plus_latent=False,
                                  inject_index=inject_index,
                                  use_res=False)
        target_img = target_img.detach()
        style_img, _ = generator([sample_zs[1]],
                                 None,
                                 input_is_latent=False,
                                 z_plus_latent=False,
                                 use_res=False)
        _, sample_style = encoder(F.adaptive_avg_pool2d(style_img, 256),
                                  randomize_noise=False,
                                  return_latents=True,
                                  z_plus_latent=True,
                                  return_z_plus_latent=False)
        sample_style = sample_style.detach()
        if get_rank() == 0:
            utils.save_image(F.adaptive_avg_pool2d(source_img, 256),
                             f"log/{args.model_name}-instyle.jpg",
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))
            utils.save_image(F.adaptive_avg_pool2d(target_img, 256),
                             f"log/{args.model_name}-target.jpg",
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))
            utils.save_image(F.adaptive_avg_pool2d(style_img, 256),
                             f"log/{args.model_name}-exstyle.jpg",
                             nrow=int(args.n_sample**0.5),
                             normalize=True,
                             range=(-1, 1))

    for idx in pbar:
        i = idx + args.start_iter

        which = i % args.subspace_freq

        if i > args.iter:
            print("Done!")
            break

        real_img = next(loader)
        real_img = real_img.to(device)

        # real_zs contains z1 and z2
        real_zs = mixing_noise(args.batch, args.latent, 1.0, device)
        with torch.no_grad():
            # g(z^+_l) with l=inject_index
            target_img, _ = generator(real_zs,
                                      None,
                                      input_is_latent=False,
                                      z_plus_latent=False,
                                      inject_index=inject_index,
                                      use_res=False)
            target_img = target_img.detach()
            # g(z2)
            style_img, _ = generator([real_zs[1]],
                                     None,
                                     input_is_latent=False,
                                     z_plus_latent=False,
                                     use_res=False)
            style_img = style_img.detach()
            # E(g(z2))
            _, pspstyle = encoder(F.adaptive_avg_pool2d(style_img, 256),
                                  randomize_noise=False,
                                  return_latents=True,
                                  z_plus_latent=True,
                                  return_z_plus_latent=False)
            pspstyle = pspstyle.detach()

        # ---- Discriminator step ----
        requires_grad(generator, False)
        requires_grad(discriminator, True)

        if which > 0:
            # set z~_2 = z2
            noise = [real_zs[0]]
            externalstyle = g_module.get_latent(real_zs[1]).detach()
            z_plus_latent = False
        else:
            # set z~_2 = E(g(z2))
            noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
            externalstyle = pspstyle
            z_plus_latent = True

        fake_img, _ = generator(noise,
                                externalstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        if args.augment:
            real_img_aug, _ = augment(real_img, ada_aug_p)
            fake_img, _ = augment(fake_img, ada_aug_p)

        else:
            real_img_aug = real_img

        fake_pred = discriminator(fake_img)
        real_pred = discriminator(real_img_aug)
        d_loss = d_logistic_loss(real_pred, fake_pred) * 0.1

        loss_dict["d"] = d_loss  # Ladv
        loss_dict["real_score"] = real_pred.mean()
        loss_dict["fake_score"] = fake_pred.mean()

        discriminator.zero_grad()
        d_loss.backward()
        d_optim.step()

        if args.augment and args.augment_p == 0:
            ada_aug_p = ada_augment.tune(real_pred)
            r_t_stat = ada_augment.r_t_stat

        # ---- Lazy R1 regularization ----
        d_regularize = i % args.d_reg_every == 0

        if d_regularize:
            real_img.requires_grad = True

            if args.augment:
                real_img_aug, _ = augment(real_img, ada_aug_p)

            else:
                real_img_aug = real_img

            real_pred = discriminator(real_img_aug)
            r1_loss = d_r1_loss(real_pred, real_img)

            discriminator.zero_grad()
            (args.r1 / 2 * r1_loss * args.d_reg_every +
             0 * real_pred[0]).backward()

            d_optim.step()

        loss_dict["r1"] = r1_loss

        # ---- Generator step ----
        requires_grad(generator, True)
        requires_grad(discriminator, False)

        if which > 0:
            # set z~_2 = z2
            noise = [real_zs[0]]
            externalstyle = g_module.get_latent(real_zs[1]).detach()
            z_plus_latent = False
        else:
            # set z~_2 = E(g(z2))
            noise = [real_zs[0].unsqueeze(1).repeat(1, g_module.n_latent, 1)]
            externalstyle = pspstyle
            z_plus_latent = True

        fake_img, _ = generator(noise,
                                externalstyle,
                                use_res=True,
                                z_plus_latent=z_plus_latent)

        # Target features are constants; no graph through vggloss is needed.
        with torch.no_grad():
            real_feats = vggloss(
                F.adaptive_avg_pool2d(target_img, 256).detach())
        fake_feats = vggloss(F.adaptive_avg_pool2d(fake_img, 256))
        gr_loss = torch.tensor(0.0, device=device)
        for ii, weight in enumerate(vgg_weights):
            if weight > 0:
                gr_loss += F.l1_loss(fake_feats[ii],
                                     real_feats[ii].detach()) * weight

        if args.augment:
            fake_img, _ = augment(fake_img, ada_aug_p)

        fake_pred = discriminator(fake_img)
        g_loss = g_nonsaturating_loss(fake_pred) * 0.1

        loss_dict["g"] = g_loss  # Ladv
        loss_dict["gr"] = gr_loss  # L_perc

        # Out-of-place: an in-place `+=` would also modify the tensor stored
        # in loss_dict["g"], making the logged adversarial loss include gr.
        g_loss = g_loss + gr_loss

        generator.zero_grad()
        g_loss.backward()
        g_optim.step()

        # ---- Lazy path-length regularization ----
        g_regularize = i % args.g_reg_every == 0

        if g_regularize:
            path_batch_size = max(1, args.batch // args.path_batch_shrink)

            noise = mixing_noise(path_batch_size, args.latent, args.mixing,
                                 device)
            # Same latent width as the mixing_noise call above.
            externalstyle = torch.randn(path_batch_size,
                                        args.latent,
                                        device=device)
            externalstyle = g_module.get_latent(externalstyle).detach()
            fake_img, latents = generator(noise,
                                          externalstyle,
                                          return_latents=True,
                                          use_res=True,
                                          z_plus_latent=False)

            path_loss, mean_path_length, path_lengths = g_path_regularize(
                fake_img, latents, mean_path_length)

            generator.zero_grad()
            weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss

            if args.path_batch_shrink:
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]

            weighted_path_loss.backward()

            g_optim.step()

            mean_path_length_avg = (reduce_sum(mean_path_length).item() /
                                    get_world_size())

        loss_dict["path"] = path_loss
        loss_dict["path_length"] = path_lengths.mean()

        # Only the residual branch is averaged into g_ema.
        accumulate(g_ema.res, g_module.res, accum)

        # ---- Logging / sampling / checkpointing ----
        loss_reduced = reduce_loss_dict(loss_dict)

        d_loss_val = loss_reduced["d"].mean().item()
        g_loss_val = loss_reduced["g"].mean().item()
        gr_loss_val = loss_reduced["gr"].mean().item()
        r1_val = loss_reduced["r1"].mean().item()
        path_loss_val = loss_reduced["path"].mean().item()
        real_score_val = loss_reduced["real_score"].mean().item()
        fake_score_val = loss_reduced["fake_score"].mean().item()
        path_length_val = loss_reduced["path_length"].mean().item()

        if get_rank() == 0:
            pbar.set_description((
                f"iter: {i:d}; d: {d_loss_val:.3f}; g: {g_loss_val:.3f}; gr: {gr_loss_val:.3f}; r1: {r1_val:.3f}; "
                f"path: {path_loss_val:.3f}; mean path: {mean_path_length_avg:.3f}; "
                f"augment: {ada_aug_p:.1f}"))

            if i % 300 == 0 or (i + 1) == args.iter:
                with torch.no_grad():
                    g_ema.eval()
                    sample, _ = g_ema([
                        sample_zs[0].unsqueeze(1).repeat(
                            1, g_module.n_latent, 1)
                    ],
                                      sample_style,
                                      use_res=True,
                                      z_plus_latent=True)
                    sample = F.interpolate(sample, 256)
                    utils.save_image(
                        sample,
                        f"log/{args.model_name}-{i:06d}.jpg",
                        nrow=int(args.n_sample**0.5),
                        normalize=True,
                        range=(-1, 1),
                    )

            if savemodel and ((i + 1) % args.save_every == 0 or
                              (i + 1) == args.iter):
                torch.save(
                    {
                        "g": g_module.state_dict(),
                        "d": d_module.state_dict(),
                        "g_ema": g_ema.state_dict(),
                        "g_optim": g_optim.state_dict(),
                        "d_optim": d_optim.state_dict(),
                        "args": args,
                        "ada_aug_p": ada_aug_p,
                    },
                    f"{args.model_path}/{args.model_name}-{i + 1:06d}.pt",
                )