def validate(epoch):
    if args.dynamic_channel:  # we also evaluate the model with half channels
        set_uniform_channel_ratio(g_ema, 0.5)
        fid = measure_fid()
        reset_generator(g_ema)
        if hvd.rank() == 0:
            print(' * FID-0.5x: {:.2f}'.format(fid))
            log_writer.add_scalar('Metrics/fid-0.5x', fid,
                                  len(data_loader) * (epoch + 1) * args.batch_size * hvd.size())

    fid = measure_fid()
    if hvd.rank() == 0:
        log_writer.add_scalar('Metrics/fid', fid,
                              len(data_loader) * (epoch + 1) * args.batch_size * hvd.size())

    global best_fid
    best_fid = min(best_fid, fid)

    if hvd.rank() == 0:
        print(' * FID: {:.2f} ({:.2f})'.format(fid, best_fid))
        state_dict = {
            "g": generator.state_dict(),
            "d": discriminator.state_dict(),
            "g_ema": g_ema.state_dict(),
            "g_optim": g_optim.state_dict(),
            "d_optim": d_optim.state_dict(),
            "epoch": epoch + 1,
            "fid": fid,
            "best_fid": best_fid,
            "mean_path_length": mean_path_length,
        }
        torch.save(state_dict, os.path.join(checkpoint_dir, args.job, 'ckpt.pt'))
        if best_fid == fid:
            torch.save(state_dict, os.path.join(checkpoint_dir, args.job, 'ckpt-best.pt'))
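# Hedged sketch (not part of the original script): how a run could be resumed from the
# checkpoint that validate() writes above. The keys mirror the state_dict saved there; the
# helper name maybe_resume and the surrounding globals (generator, discriminator, g_ema,
# g_optim, d_optim, best_fid, mean_path_length) are assumed to already exist in this scope.
def maybe_resume(ckpt_path):
    import os
    import torch
    if not os.path.exists(ckpt_path):
        return 0  # nothing to resume; start from epoch 0
    ckpt = torch.load(ckpt_path, map_location='cpu')
    generator.load_state_dict(ckpt['g'])
    discriminator.load_state_dict(ckpt['d'])
    g_ema.load_state_dict(ckpt['g_ema'])
    g_optim.load_state_dict(ckpt['g_optim'])
    d_optim.load_state_dict(ckpt['d_optim'])
    global best_fid, mean_path_length
    best_fid = ckpt['best_fid']
    mean_path_length = ckpt['mean_path_length']
    return ckpt['epoch']  # validate() stores epoch + 1, i.e. the next epoch to run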
def project_images(images):
    with torch.no_grad():
        if encoder is not None:
            styles = encoder(adaptive_resize(images, 256))
        else:
            styles = generator.mean_style(10000).view(1, 1, -1).repeat(
                images.shape[0], generator.n_style, 1)
    init_styles = styles.detach().clone()

    input_kwargs = {
        'styles': styles,
        'noise': None,
        'randomize_noise': False,
        'input_is_style': True
    }
    styles.requires_grad = True  # we only optimize the styles but not noise; with noise it is harder to manipulate

    if args.optimizer.lower() == "lbfgs":
        optimizer = LBFGS.FullBatchLBFGS([styles], lr=1)
    elif args.optimizer.lower() == "adam":
        optimizer = torch.optim.Adam([styles], lr=0.001)
    else:
        raise NotImplementedError

    with torch.no_grad():
        init_image = generator(**input_kwargs)[0].clamp(-1, 1)
        loss, loss_dict = compute_loss_sum(init_image, images, styles)

    loss_list = []
    loss_list.append(loss_dict)

    pbar = tqdm(range(args.n_iter))
    for _ in pbar:
        if isinstance(optimizer, LBFGS.FullBatchLBFGS):
            def closure():
                optimizer.zero_grad()
                process_generator()
                out = generator(**input_kwargs)[0].clamp(-1, 1)
                reset_generator(generator)
                loss, loss_dict = compute_loss_sum(out, images, styles)
                loss_list.append(loss_dict)
                return loss

            options = {'closure': closure, 'current_loss': loss, 'max_ls': 10}
            loss, grad, lr, _, _, _, _, _ = optimizer.step(options=options)
        else:
            process_generator()
            out = generator(**input_kwargs)[0]
            reset_generator(generator)
            loss, loss_dict = compute_loss_sum(out, images, styles)
            loss.backward()
            optimizer.step()
            loss_list.append(loss_dict)
        pbar.set_postfix(loss_list[-1])
    return styles.detach()
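# Hedged usage sketch (not part of the original script): projecting a single image file and
# re-rendering it from the recovered styles. The helper name project_file, the file loading,
# the 1024px target size, and the [-1, 1] normalization are assumptions; project_images()
# itself only needs a float image batch on the same device as the generator.
def project_file(path, resolution=1024, device='cuda'):
    import torch
    from PIL import Image
    from torchvision import transforms

    preprocess = transforms.Compose([
        transforms.Resize(resolution),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # map [0, 1] -> [-1, 1]
    ])
    image = preprocess(Image.open(path).convert('RGB')).unsqueeze(0).to(device)

    styles = project_images(image)  # optimized, detached style codes
    with torch.no_grad():
        # re-render with the same kwargs that project_images() optimizes against
        recon = generator(styles=styles, noise=None, randomize_noise=False,
                          input_is_style=True)[0].clamp(-1, 1)
    return styles, recon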
def process_generator():
    if args.optimize_sub_g:
        if evolve_cfgs is not None:  # the generator is trained with elastic channels and evolved
            if random.random() < 0.5:  # randomly pick an evolution config
                rand_cfg = random.choice(list(evolve_cfgs.values()))
                set_sub_channel_config(generator, rand_cfg['channels'])
                generator.target_res = rand_cfg['res']
            else:
                reset_generator(generator)  # full G
        else:
            set_uniform_channel_ratio(generator, random.choice(CHANNEL_CONFIGS))
            generator.target_res = random.choice([256, 512, 1024])
    else:
        pass  # keep the full generator untouched
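# Hedged illustration (not from the original script) of the structure process_generator()
# expects for evolve_cfgs: each entry holds a per-layer channel configuration plus the output
# resolution it was evolved for. The key names and numbers below are made up for illustration.
example_evolve_cfgs = {
    'evolved-small': {'channels': [512, 512, 512, 256, 128, 64, 32, 16, 8], 'res': 512},
    'evolved-large': {'channels': [512, 512, 512, 512, 256, 128, 64, 32, 16], 'res': 1024},
}
# random.choice(list(example_evolve_cfgs.values())) yields one such dict, whose 'channels'
# list goes to set_sub_channel_config() and whose 'res' sets generator.target_res.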
def train(epoch):
    generator.train()
    discriminator.train()
    g_ema.eval()
    sampler.set_epoch(epoch)

    with tqdm(total=len(data_loader), desc='Epoch #{}'.format(epoch + 1),
              disable=hvd.rank() != 0, dynamic_ncols=True) as t:
        global mean_path_length  # track across epochs
        ema_decay = 0.5 ** (args.batch_size * hvd.size() / (args.half_life_kimg * 1000.))

        # loss meters
        d_loss_meter = DistributedMeter('d_loss')
        r1_loss_meter = DistributedMeter('r1_loss')
        g_loss_meter = DistributedMeter('g_loss')
        path_loss_meter = DistributedMeter('path_loss')
        d_real_acc = DistributedMeter('d_real_acc')
        d_fake_acc = DistributedMeter('d_fake_acc')
        distill_loss_meter = DistributedMeter('distill_loss')

        for batch_idx, real_img in enumerate(data_loader):
            global_idx = batch_idx + epoch * len(data_loader) + 1
            if args.n_res > 1:
                real_img = [ri.to(device) for ri in real_img]  # a stack of images
            else:
                real_img = real_img.to(device)

            # 1. train D
            requires_grad(generator, False)
            requires_grad(discriminator, True)

            z = get_mixing_z(args.batch_size, args.latent_dim, args.mixing_prob, device)
            with torch.no_grad():
                if args.dynamic_channel:
                    rand_ratio = sample_random_sub_channel(
                        generator,
                        min_channel=args.min_channel,
                        divided_by=args.divided_by,
                        mode=args.dynamic_channel_mode,
                    )
                fake_img, all_rgbs = generator(z, return_rgbs=True)
                all_rgbs = all_rgbs[-args.n_res:]
                reset_generator(generator)

            if args.n_res > 1:
                sampled_res = random.sample(all_resolutions, args.n_sampled_res)
                d_loss = 0.
                g_arch = get_g_arch(rand_ratio) if args.conditioned_d else None
                rand_g_arch = get_random_g_arch(  # randomly draw one for real images
                    generator, args.min_channel, args.divided_by, args.dynamic_channel_mode
                ) if args.conditioned_d else None
                for ri, fi in zip(real_img, all_rgbs):
                    if ri.shape[-1] in sampled_res:
                        real_pred = discriminator(ri, rand_g_arch)
                        fake_pred = discriminator(fi, g_arch)
                        d_loss += d_logistic_loss(real_pred, fake_pred)
            else:
                assert not args.conditioned_d  # not implemented yet
                fake_pred = discriminator(fake_img)
                real_pred = discriminator(real_img)
                d_loss = d_logistic_loss(real_pred, fake_pred)

            d_real_acc.update((real_pred > 0).sum() * 1. / real_pred.shape[0])
            d_fake_acc.update((fake_pred < 0).sum() * 1. / real_pred.shape[0])
            d_loss_meter.update(d_loss)

            discriminator.zero_grad()
            d_loss.backward()
            d_optim.step()

            # reg D
            if args.d_reg_every > 0 and global_idx % args.d_reg_every == 0:
                reg_img = random.choice(real_img) if args.n_res > 1 else real_img
                reg_img.requires_grad = True
                if args.conditioned_d:
                    real_pred = discriminator(reg_img, g_arch)
                else:
                    real_pred = discriminator(reg_img)
                r1_loss = d_r1_loss(real_pred, reg_img)

                discriminator.zero_grad()
                (args.r1 / 2 * r1_loss * args.d_reg_every + 0 * real_pred[0]).backward()
                d_optim.step()
                r1_loss_meter.update(r1_loss)

            # 2. train G
            requires_grad(generator, True)
            requires_grad(discriminator, False)

            z = get_mixing_z(args.batch_size, args.latent_dim, args.mixing_prob, device)
            # fix the randomness (potentially apply distillation)
            noises = generator.make_noise()
            inject_index = None if z.shape[1] == 1 else random.randint(1, generator.n_style - 1)

            if args.dynamic_channel:
                rand_ratio = sample_random_sub_channel(generator,
                                                       min_channel=args.min_channel,
                                                       divided_by=args.divided_by,
                                                       mode=args.dynamic_channel_mode)
            fake_img, all_rgbs = generator(z, noise=noises, inject_index=inject_index, return_rgbs=True)
            all_rgbs = all_rgbs[-args.n_res:]
            reset_generator(generator)

            # g loss
            if args.n_res > 1:
                sampled_rgbs = random.sample(all_rgbs, args.n_sampled_res)
                g_arch = get_g_arch(rand_ratio) if args.conditioned_d else None
                g_loss = sum([g_nonsaturating_loss(discriminator(r, g_arch)) for r in sampled_rgbs])
            else:
                g_loss = g_nonsaturating_loss(discriminator(fake_img))

            # distill loss
            if teacher is not None:
                with torch.no_grad():
                    teacher_out, _ = teacher(z, noise=noises, inject_index=inject_index)
                    teacher_rgbs = get_teacher_multi_res(teacher_out, args.n_res)
                distill_loss1 = sum([nn.MSELoss()(sr, tr) for sr, tr in zip(all_rgbs, teacher_rgbs)])
                distill_loss2 = sum([percept(adaptive_downsample256(sr), adaptive_downsample256(tr)).mean()
                                     for sr, tr in zip(all_rgbs, teacher_rgbs)])
                distill_loss = distill_loss1 + distill_loss2
                g_loss = g_loss + distill_loss * args.distill_loss_alpha
                distill_loss_meter.update(distill_loss * args.distill_loss_alpha)

            g_loss_meter.update(g_loss)

            generator.zero_grad()
            g_loss.backward()
            g_optim.step()

            # reg G
            if args.g_reg_every > 0 and global_idx % args.g_reg_every == 0:  # path len reg
                assert args.n_res == 1  # currently, we do not apply path reg after the original StyleGAN training
                path_batch_size = max(1, args.batch_size // args.path_batch_shrink)
                noise = get_mixing_z(path_batch_size, args.latent_dim, args.mixing_prob, device)
                fake_img, latents = generator(noise, return_styles=True)
                # moving update the mean path length
                path_loss, mean_path_length, path_lengths = g_path_regularize(
                    fake_img, latents, mean_path_length
                )

                generator.zero_grad()
                weighted_path_loss = args.path_regularize * args.g_reg_every * path_loss
                # special trick to trigger sync gradient descent  TODO: do we need it here?
                weighted_path_loss += 0 * fake_img[0, 0, 0, 0]
                weighted_path_loss.backward()
                g_optim.step()

                mean_path_length = hvd.allreduce(torch.Tensor([mean_path_length])).item()  # update across gpus
                path_loss_meter.update(path_loss)

            # moving update
            accumulate(g_ema, generator, ema_decay)

            info2display = {
                'd': d_loss_meter.avg.item(),
                'g': g_loss_meter.avg.item(),
                'r1': r1_loss_meter.avg.item(),
                'd_real_acc': d_real_acc.avg.item(),
                'd_fake_acc': d_fake_acc.avg.item()
            }
            if teacher is not None:
                info2display['dist'] = distill_loss_meter.avg.item()
            if args.g_reg_every > 0:
                info2display['path'] = path_loss_meter.avg.item()
                info2display['path-len'] = mean_path_length
            t.set_postfix(info2display)
            t.update(1)

            if hvd.rank() == 0 and global_idx % args.log_every == 0:
                n_trained_images = global_idx * args.batch_size * hvd.size()
                log_writer.add_scalar('Loss/D', d_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/G', g_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/r1', r1_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/path', path_loss_meter.avg.item(), n_trained_images)
                log_writer.add_scalar('Loss/path-len', mean_path_length, n_trained_images)
                log_writer.add_scalar('Loss/distill', distill_loss_meter.avg.item(), n_trained_images)

            if hvd.rank() == 0 and global_idx % args.log_vis_every == 0:  # log image
                with torch.no_grad():
                    g_ema.eval()
                    mean_style = g_ema.mean_style(10000)
                    sample, _ = g_ema(sample_z, truncation=args.vis_truncation, truncation_style=mean_style)
                    n_trained_images = global_idx * args.batch_size * hvd.size()
                    grid = utils.make_grid(sample, nrow=int(args.n_vis_sample ** 0.5),
                                           normalize=True, range=(-1, 1))
                    log_writer.add_image('images', grid, n_trained_images)
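# Hedged reference sketches (not from the original script) of the helpers used in train()
# above, written against the standard StyleGAN2 formulation. They are assumptions about what
# d_logistic_loss, g_nonsaturating_loss, d_r1_loss, requires_grad and accumulate compute,
# not necessarily this repository's exact implementations; the _sketch suffix marks them as
# illustrative copies rather than replacements.
import torch
from torch.nn import functional as F

def d_logistic_loss_sketch(real_pred, fake_pred):
    # non-saturating logistic D loss: push real logits up, fake logits down
    return F.softplus(-real_pred).mean() + F.softplus(fake_pred).mean()

def g_nonsaturating_loss_sketch(fake_pred):
    # non-saturating logistic G loss on the discriminator's fake logits
    return F.softplus(-fake_pred).mean()

def d_r1_loss_sketch(real_pred, real_img):
    # R1 penalty: squared gradient norm of D's output w.r.t. real images
    grad_real, = torch.autograd.grad(outputs=real_pred.sum(), inputs=real_img,
                                     create_graph=True)
    return grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()

def requires_grad_sketch(model, flag=True):
    # freeze/unfreeze all parameters of a network (used to alternate D and G updates)
    for p in model.parameters():
        p.requires_grad = flag

def accumulate_sketch(model_ema, model, decay):
    # exponential moving average of generator weights into g_ema
    ema_params = dict(model_ema.named_parameters())
    params = dict(model.named_parameters())
    for k in ema_params:
        ema_params[k].data.mul_(decay).add_(params[k].data, alpha=1 - decay)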