Example #1
0
def validate(epoch):
    """Evaluate g_ema by FID (plus its half-channel variant when dynamic
    channels are enabled), log the metrics, and checkpoint the model.

    Updates the module-level `best_fid`; rank 0 saves 'ckpt.pt' every call
    and 'ckpt-best.pt' whenever the current FID matches the best so far.
    """
    is_chief = hvd.rank() == 0
    # x-axis for the logger: total images seen after this epoch, across workers
    n_trained_images = len(data_loader) * (epoch + 1) * args.batch_size * hvd.size()

    if args.dynamic_channel:  # we also evaluate the model with half channels
        set_uniform_channel_ratio(g_ema, 0.5)
        half_fid = measure_fid()
        reset_generator(g_ema)  # restore the full-channel generator
        if is_chief:
            print(' * FID-0.5x: {:.2f}'.format(half_fid))
            log_writer.add_scalar('Metrics/fid-0.5x', half_fid, n_trained_images)

    fid = measure_fid()
    if is_chief:
        log_writer.add_scalar('Metrics/fid', fid, n_trained_images)

    global best_fid
    best_fid = min(best_fid, fid)

    if is_chief:
        print(' * FID: {:.2f} ({:.2f})'.format(fid, best_fid))
        state_dict = {
            "g": generator.state_dict(),
            "d": discriminator.state_dict(),
            "g_ema": g_ema.state_dict(),
            "g_optim": g_optim.state_dict(),
            "d_optim": d_optim.state_dict(),
            "epoch": epoch + 1,
            "fid": fid,
            "best_fid": best_fid,
            "mean_path_length": mean_path_length,
        }
        torch.save(state_dict, os.path.join(checkpoint_dir, args.job, 'ckpt.pt'))
        if best_fid == fid:  # this epoch set a new best
            torch.save(state_dict, os.path.join(checkpoint_dir, args.job, 'ckpt-best.pt'))
Example #2
0
 def closure():
     """Re-evaluate the objective for LBFGS: forward through the (possibly
     sub-sampled) generator, record the per-term losses, return the total."""
     optimizer.zero_grad()
     process_generator()  # possibly switch to a random sub-generator config
     synthesized = generator(**input_kwargs)[0].clamp(-1, 1)
     reset_generator(generator)  # restore the full generator
     total_loss, term_dict = compute_loss_sum(synthesized, images, styles)
     loss_list.append(term_dict)
     return total_loss
Example #3
0
def project_images(images):
    """Project a batch of images into the generator's style space.

    Starts from the encoder's predicted styles (or from the generator's mean
    style when no encoder is available) and optimizes only the styles -- not
    the noise -- to reconstruct `images`, using LBFGS or Adam depending on
    args.optimizer.

    Returns:
        The optimized styles tensor, detached from the graph.

    Raises:
        NotImplementedError: for an unrecognized args.optimizer.
    """
    with torch.no_grad():
        if encoder is not None:
            styles = encoder(adaptive_resize(images, 256))
        else:
            # no encoder: start every image from the generator's mean style
            styles = generator.mean_style(10000).view(1, 1, -1).repeat(
                images.shape[0], generator.n_style, 1)

    # NOTE(review): init_styles is not used in this function -- presumably
    # consumed by compute_loss_sum or callers elsewhere; kept for safety.
    init_styles = styles.detach().clone()
    input_kwargs = {
        'styles': styles,
        'noise': None,
        'randomize_noise': False,
        'input_is_style': True
    }
    styles.requires_grad = True

    # we only optimize the styles but not noise; with noise it is harder to manipulate
    if args.optimizer.lower() == "lbfgs":
        optimizer = LBFGS.FullBatchLBFGS([styles], lr=1)
    elif args.optimizer.lower() == "adam":
        optimizer = torch.optim.Adam([styles], lr=0.001)
    else:
        raise NotImplementedError

    # initial loss; LBFGS uses it to seed the first line search
    with torch.no_grad():
        init_image = generator(**input_kwargs)[0].clamp(-1, 1)
        loss, loss_dict = compute_loss_sum(init_image, images, styles)

    loss_list = [loss_dict]
    pbar = tqdm(range(args.n_iter))
    for _ in pbar:
        if isinstance(optimizer, LBFGS.FullBatchLBFGS):

            def closure():
                optimizer.zero_grad()
                process_generator()
                out = generator(**input_kwargs)[0].clamp(-1, 1)
                reset_generator(generator)
                loss, loss_dict = compute_loss_sum(out, images, styles)
                loss_list.append(loss_dict)
                return loss

            options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
            loss, grad, lr, _, _, _, _, _ = optimizer.step(options=options)
        else:
            # BUG FIX: gradients must be cleared each iteration; without this
            # the style gradients accumulate across Adam steps.
            optimizer.zero_grad()
            process_generator()
            # NOTE(review): unlike the LBFGS path, the output is not clamped to
            # [-1, 1] before the loss here -- confirm whether that is intended.
            out = generator(**input_kwargs)[0]
            reset_generator(generator)
            loss, loss_dict = compute_loss_sum(out, images, styles)
            loss.backward()
            optimizer.step()
            loss_list.append(loss_dict)
        pbar.set_postfix(loss_list[-1])
    return styles.detach()
Example #4
0
def process_generator():
    """Optionally switch the global generator to a random sub-network
    configuration before a forward pass.

    No-op when args.optimize_sub_g is disabled. Callers are expected to call
    reset_generator() afterwards to restore the full generator.
    """
    if not args.optimize_sub_g:
        return
    if evolve_cfgs is not None:  # the generator is trained with elastic channels and evolved
        if random.random() < 0.5:  # randomly pick an evolution config
            # BUG FIX: random.sample() requires a sample size `k` and would
            # raise TypeError here; use random.choice() to draw one config.
            # NOTE(review): assumes evolve_cfgs keys are mappings carrying
            # 'channels' and 'res' -- confirm against where evolve_cfgs is built.
            rand_cfg = random.choice(list(evolve_cfgs.keys()))
            set_sub_channel_config(generator, rand_cfg['channels'])
            generator.target_res = rand_cfg['res']
        else:
            reset_generator(generator)  # full G
    else:
        # elastic channels without evolution: sample a uniform channel ratio
        # and a random target resolution
        set_uniform_channel_ratio(generator,
                                  random.choice(CHANNEL_CONFIGS))
        generator.target_res = random.choice([256, 512, 1024])
Example #5
0
def train(epoch):
    """Run one epoch of distributed GAN training.

    Per batch: (1) discriminator step with optional lazy R1 regularization,
    (2) generator step with optional teacher distillation and lazy path-length
    regularization, then an EMA update of g_ema and periodic TensorBoard
    logging on rank 0.

    Relies on module-level state: args, generator/discriminator/g_ema,
    g_optim/d_optim, data_loader, sampler, teacher, hvd, log_writer and the
    global mean_path_length.
    """
    generator.train()
    discriminator.train()
    g_ema.eval()
    sampler.set_epoch(epoch)  # reshuffle the distributed sampler each epoch

    with tqdm(total=len(data_loader),
              desc='Epoch #{}'.format(epoch + 1),
              disable=hvd.rank() != 0, dynamic_ncols=True) as t:
        global mean_path_length  # track across epochs

        # EMA decay chosen so the half-life equals args.half_life_kimg kimg
        ema_decay = 0.5 ** (args.batch_size * hvd.size() / (args.half_life_kimg * 1000.))

        # loss meters
        d_loss_meter = DistributedMeter('d_loss')
        r1_loss_meter = DistributedMeter('r1_loss')
        g_loss_meter = DistributedMeter('g_loss')
        path_loss_meter = DistributedMeter('path_loss')
        d_real_acc = DistributedMeter('d_real_acc')
        d_fake_acc = DistributedMeter('d_fake_acc')
        distill_loss_meter = DistributedMeter('distill_loss')

        for batch_idx, real_img in enumerate(data_loader):
            # 1-based global step, shared by both lazy-regularization schedules
            global_idx = batch_idx + epoch * len(data_loader) + 1
            if args.n_res > 1:
                real_img = [ri.to(device) for ri in real_img]  # a stack of images
            else:
                real_img = real_img.to(device)

            # 1. train D
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            z = get_mixing_z(args.batch_size, args.latent_dim, args.mixing_prob, device)
            with torch.no_grad():
                if args.dynamic_channel:
                    # sample a random sub-channel configuration for this step
                    rand_ratio = sample_random_sub_channel(
                        generator, min_channel=args.min_channel,
                        divided_by=args.divided_by,
                        mode=args.dynamic_channel_mode,
                    )
                fake_img, all_rgbs = generator(z, return_rgbs=True)
                all_rgbs = all_rgbs[-args.n_res:]  # keep the last n_res output resolutions
                reset_generator(generator)  # restore the full generator

            if args.n_res > 1:
                # multi-resolution D: only compare at randomly sampled resolutions
                sampled_res = random.sample(all_resolutions, args.n_sampled_res)
                d_loss = 0.
                g_arch = get_g_arch(rand_ratio) if args.conditioned_d else None
                rand_g_arch = get_random_g_arch(  # randomly draw one for real images
                    generator, args.min_channel, args.divided_by, args.dynamic_channel_mode
                ) if args.conditioned_d else None
                for ri, fi in zip(real_img, all_rgbs):
                    if ri.shape[-1] in sampled_res:
                        real_pred = discriminator(ri, rand_g_arch)
                        fake_pred = discriminator(fi, g_arch)
                        d_loss += d_logistic_loss(real_pred, fake_pred)
            else:
                assert not args.conditioned_d  # not implemented yet
                fake_pred = discriminator(fake_img)
                real_pred = discriminator(real_img)
                d_loss = d_logistic_loss(real_pred, fake_pred)

            # NOTE(review): in the multi-res branch these accuracies use only the
            # predictions from the last sampled resolution (and would be stale or
            # undefined if no resolution matched) -- confirm this is intended.
            d_real_acc.update((real_pred > 0).sum() * 1. / real_pred.shape[0])
            d_fake_acc.update((fake_pred < 0).sum() * 1. / real_pred.shape[0])
            d_loss_meter.update(d_loss)

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            # reg D
            if args.d_reg_every > 0 and global_idx % args.d_reg_every == 0:
                reg_img = random.choice(real_img) if args.n_res > 1 else real_img
                reg_img.requires_grad = True  # needed for the R1 gradient penalty

                if args.conditioned_d:
                    real_pred = discriminator(reg_img, g_arch)
                else:
                    real_pred = discriminator(reg_img)
                r1_loss = d_r1_loss(real_pred, reg_img)

                discriminator.zero_grad()
                # lazy regularization (scaled by d_reg_every); the 0 * real_pred[0]
                # term keeps the output in the graph so gradients stay in sync
                (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
                d_optim.step()
                r1_loss_meter.update(r1_loss)

            # 2. train G
            requires_grad(generator, True)
            requires_grad(discriminator, False)

            z = get_mixing_z(args.batch_size, args.latent_dim, args.mixing_prob, device)
            # fix the randomness (potentially apply distillation)
            noises = generator.make_noise()
            inject_index = None if z.shape[1] == 1 else random.randint(1, generator.n_style - 1)

            if args.dynamic_channel:
                rand_ratio = sample_random_sub_channel(generator, min_channel=args.min_channel,
                                                       divided_by=args.divided_by,
                                                       mode=args.dynamic_channel_mode)
            fake_img, all_rgbs = generator(z, noise=noises, inject_index=inject_index, return_rgbs=True)
            all_rgbs = all_rgbs[-args.n_res:]
            reset_generator(generator)

            # g loss
            if args.n_res > 1:
                sampled_rgbs = random.sample(all_rgbs, args.n_sampled_res)
                g_arch = get_g_arch(rand_ratio) if args.conditioned_d else None
                g_loss = sum([g_nonsaturating_loss(discriminator(r, g_arch)) for r in sampled_rgbs])
            else:
                g_loss = g_nonsaturating_loss(discriminator(fake_img))

            # distill loss
            if teacher is not None:
                with torch.no_grad():
                    # teacher shares z/noise/inject_index so outputs are comparable
                    teacher_out, _ = teacher(z, noise=noises, inject_index=inject_index)
                teacher_rgbs = get_teacher_multi_res(teacher_out, args.n_res)
                # pixel-wise MSE plus perceptual distance, per resolution
                distill_loss1 = sum([nn.MSELoss()(sr, tr) for sr, tr in zip(all_rgbs, teacher_rgbs)])
                distill_loss2 = sum([percept(adaptive_downsample256(sr), adaptive_downsample256(tr)).mean()
                                     for sr, tr in zip(all_rgbs, teacher_rgbs)])
                distill_loss = distill_loss1 + distill_loss2
                g_loss = g_loss + distill_loss * args.distill_loss_alpha
                distill_loss_meter.update(distill_loss * args.distill_loss_alpha)

            g_loss_meter.update(g_loss)

            generator.zero_grad()
            g_loss.backward()
            g_optim.step()

            # reg G
            if args.g_reg_every > 0 and global_idx % args.g_reg_every == 0:  # path len reg
                assert args.n_res == 1  # currently, we do not apply path reg after the original StyleGAN training
                path_batch_size = max(1, args.batch_size // args.path_batch_shrink)
                noise = get_mixing_z(path_batch_size, args.latent_dim, args.mixing_prob, device)
                fake_img, latents = generator(noise, return_styles=True)
                # moving update the mean path length
                path_loss, mean_path_length, path_lengths = g_path_regularize(
                    fake_img, latents, mean_path_length
                )
                generator.zero_grad()
                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                # special trick to trigger sync gradient descent TODO: do we need it here?
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
                weighted_path_loss.backward()
                g_optim.step()
                mean_path_length = hvd.allreduce(torch.Tensor([mean_path_length])).item()  # update across gpus
                path_loss_meter.update(path_loss)

            # moving update
            accumulate(g_ema, generator, ema_decay)

            # progress-bar payload
            info2display = {
                'd': d_loss_meter.avg.item(),
                'g': g_loss_meter.avg.item(),
                'r1': r1_loss_meter.avg.item(),
                'd_real_acc': d_real_acc.avg.item(),
                'd_fake_acc': d_fake_acc.avg.item()
            }
            if teacher is not None:
                info2display['dist'] = distill_loss_meter.avg.item()
            if args.g_reg_every > 0:
                info2display['path'] = path_loss_meter.avg.item()
                info2display['path-len'] = mean_path_length

            t.set_postfix(info2display)
            t.update(1)

            # scalar logging, rank 0 only
            if hvd.rank() == 0 and global_idx % args.log_every == 0:
                n_trained_images = global_idx * args.batch_size * hvd.size()
                log_writer.add_scalar('Loss/D', d_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/G', g_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/r1', r1_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/path', path_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/path-len', mean_path_length, n_trained_images)
                log_writer.add_scalar('Loss/distill', distill_loss_meter.avg.item(), n_trained_images)

            if hvd.rank() == 0 and global_idx % args.log_vis_every == 0:  # log image
                with torch.no_grad():
                    g_ema.eval()
                    mean_style = g_ema.mean_style(10000)  # style used for truncation
                    sample, _ = g_ema(sample_z, truncation=args.vis_truncation, truncation_style=mean_style)
                    n_trained_images = global_idx * args.batch_size * hvd.size()
                    grid = utils.make_grid(sample, nrow=int(args.n_vis_sample ** 0.5), normalize=True,
                                           range=(-1, 1))
                    log_writer.add_image('images', grid, n_trained_images)