def __init__(self):
    """Set up a perceptual loss that compares features of a frozen VGG19.

    The VGG19 network is switched to eval mode and its parameters are
    excluded from gradient computation before being registered as
    ``self.vgg``. An L1 criterion is stored in ``self.criterion`` and five
    scale factors (1/32, 1/16, 1/8, 1/4, 1) in ``self.weights``. Presumably
    there is one weight per feature map returned by ``VGG19.forward``; confirm
    against this class's ``forward`` (not shown here).
    """
    super(VGGLoss, self).__init__()
    # Build and freeze the feature extractor before attaching it.
    feature_net = VGG19()
    feature_net.eval()
    util.set_requires_grad(feature_net, False)
    self.vgg = feature_net

    self.criterion = nn.L1Loss()
    self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]
def optimize_parameters(self, steps):
    """Run one joint discriminator/generator training step.

    Gradients of both optimizers are cleared, ``forward`` is run once, and
    the discriminator and generator losses are back-propagated in that
    order. Both optimizers are stepped only after both backward passes.

    Args:
        steps: global step counter; not used by this variant.
    """
    paired_optimizers = (self.optimizer_D, self.optimizer_G)
    for optimizer in paired_optimizers:
        optimizer.zero_grad()

    self.forward()

    # netD receives gradients only while its own loss is back-propagated;
    # it is frozen again so backward_G does not accumulate into netD's grads.
    util.set_requires_grad(self.netD, True)
    self.backward_D()
    util.set_requires_grad(self.netD, False)
    self.backward_G()

    for optimizer in paired_optimizers:
        optimizer.step()
def optimize_parameters(self):
    """Run one D/G training step on a freshly sampled configuration.

    A configuration is drawn with ``self.configs.sample()`` and passed to
    ``forward``. The discriminator and generator losses are then
    back-propagated (in that order), and both optimizers are stepped after
    both backward passes.
    """
    paired_optimizers = (self.optimizer_D, self.optimizer_G)
    for optimizer in paired_optimizers:
        optimizer.zero_grad()

    sampled_config = self.configs.sample()
    self.forward(config=sampled_config)

    # Unfreeze netD only for its own backward pass so that the generator
    # loss does not leave gradients in netD's parameters.
    util.set_requires_grad(self.netD, True)
    self.backward_D()
    util.set_requires_grad(self.netD, False)
    self.backward_G()

    for optimizer in paired_optimizers:
        optimizer.step()
def optimize(opt):
    """Fit a StyleGAN2 latent for every image in a slice of the cat dataset.

    For each image, the encoder-equipped generator's ``optimize`` is called
    with the image and an estimated validity mask. The following files are
    written into ``opt.w_path``, named after the source image:

    - ``<name>_loss.npz``: the loss values returned by ``optimize``.
    - ``<name>_optimized_im.png``: the reconstructed image.
    - ``<name>.pth``: ``{'w': latent}``.

    The ``.pth`` file is written last and is the skip marker. An image whose
    ``.pth`` already exists is skipped, so an interrupted run can be resumed
    without leaving half-written results behind.

    Args:
        opt: options object. Reads:

            - ``opt.partition``: dataset split name.
            - ``opt.indices``: ``None`` for the whole split, otherwise a
              ``(start, end)`` pair. ``end`` is exclusive and clamped to the
              dataset length.
            - ``opt.w_path``: output directory.
    """
    dataset_name = 'cat'
    generator_name = 'stylegan2'
    transform = data.get_transform(dataset_name, 'im2tensor')
    dset = data.get_dataset(dataset_name, opt.partition, load_w=False,
                            transform=transform)
    total = len(dset)
    if opt.indices is None:
        start_idx = 0
        end_idx = total
    else:
        start_idx = opt.indices[0]
        # Clamp so an over-long range finishes cleanly instead of IndexError.
        end_idx = min(opt.indices[1], total)

    print("Optimizing dataset partition %s items %d to %d"
          % (opt.partition, start_idx, end_idx))

    generator = domain_generator.define_generator(generator_name, dataset_name,
                                                  load_encoder=True)
    # Freeze the generator and encoder parameters.
    util.set_requires_grad(False, generator.generator)
    util.set_requires_grad(False, generator.encoder)

    for i in range(start_idx, end_idx):
        (im, _label, path) = dset[i]
        img_filename = os.path.splitext(os.path.basename(path))[0]
        print("Running %d / %d images: %s" % (i, end_idx, img_filename))
        output_filename = os.path.join(opt.w_path, img_filename)
        if os.path.isfile(output_filename + '.pth'):
            print(output_filename + '.pth found... skipping')
            continue

        # cat face dataset is already centered
        # (assumes ``im`` is a (C, H, W) tensor from the im2tensor transform)
        centered_im = im[None].cuda()

        # Find (near-)zero pixels to estimate the mask. A pixel counts as
        # empty when the sum of its absolute values over the channel axis
        # (dim=1) is below 0.02. Presumably these are the padding around the
        # centered face; confirm. Summing over dim=0 (the singleton batch
        # axis) would be a no-op and would make the mask depend on channel 0
        # alone. The mask has one channel: shape (1, 1, H, W).
        near_zero = torch.abs(centered_im).sum(dim=1, keepdim=True) < 0.02
        mask = torch.ones_like(centered_im[:, :1]).masked_fill(near_zero, 0)

        ckpt, loss = generator.optimize(centered_im, mask=mask)
        w_optimized = ckpt['current_z']
        loss = np.array(loss).squeeze()
        im_optimized = renormalize.as_image(ckpt['current_x'][0])
        np.savez(output_filename + '_loss.npz', loss=loss)
        im_optimized.save(output_filename + '_optimized_im.png')
        # Write the .pth (the skip marker checked above) last.
        torch.save({'w': w_optimized.detach().cpu()}, output_filename + '.pth')
def optimize_parameters(self, steps):
    """Run one D/G training step, deciding whether to use the style encoder.

    The style encoder is never requested when
    ``self.opt.student_no_style_encoder`` is set. Otherwise it is requested
    on every step whose index is not a multiple of
    ``self.opt.style_encoder_step``. The decision is passed to both
    ``forward`` and ``backward_G``.

    Args:
        steps: global step counter used for the style-encoder schedule.
    """
    if self.opt.student_no_style_encoder:
        need_style_encoder = False
    else:
        # NOTE(review): True on steps that are NOT multiples of
        # style_encoder_step; confirm this (rather than == 0) is intended.
        need_style_encoder = steps % self.opt.style_encoder_step != 0

    paired_optimizers = (self.optimizer_D, self.optimizer_G)
    for optimizer in paired_optimizers:
        optimizer.zero_grad()

    sampled_config = self.configs.sample()
    self.forward(config=sampled_config, need_style_encoder=need_style_encoder)

    # netD is trainable only during its own backward pass.
    util.set_requires_grad(self.netD, True)
    self.backward_D()
    util.set_requires_grad(self.netD, False)
    self.backward_G(need_style_encoder=need_style_encoder)

    for optimizer in paired_optimizers:
        optimizer.step()
def optimize(opt):
    """Fit a StyleGAN2 latent for every CelebA-HQ image in a dataset slice.

    Each image is passed to the encoder-equipped generator's ``optimize``
    without a mask. The latent (``<name>.pth``), the loss values
    (``<name>_loss.npz``) and the reconstructed image
    (``<name>_optimized_im.png``) are written to ``opt.w_path``. Images whose
    ``.pth`` already exists are skipped.

    Args:
        opt: options object. Reads ``opt.partition``, ``opt.indices``
            (``None`` for the whole split, otherwise ``(start, end)``) and
            ``opt.w_path``.
    """
    dataset_name = 'celebahq'
    generator_name = 'stylegan2'
    transform = data.get_transform(dataset_name, 'imval')
    # we don't need the labels, so attribute doesn't really matter here
    dset = data.get_dataset(dataset_name, opt.partition, 'Smiling',
                            load_w=False, return_path=True,
                            transform=transform)
    total = len(dset)
    start_idx, end_idx = ((0, total) if opt.indices is None
                          else (opt.indices[0], opt.indices[1]))
    print("Optimizing dataset partition %s items %d to %d"
          % (opt.partition, start_idx, end_idx))

    generator = domain_generator.define_generator(
        generator_name, dataset_name, load_encoder=True)
    # Freeze generator and encoder parameters.
    for frozen_net in (generator.generator, generator.encoder):
        util.set_requires_grad(False, frozen_net)

    for idx in range(start_idx, end_idx):
        image, _, path = dset[idx]
        image = image[None].cuda()
        stem = os.path.splitext(os.path.basename(path))[0]
        print("Running %d / %d images: %s" % (idx, end_idx, stem))
        out_prefix = os.path.join(opt.w_path, stem)
        pth_file = out_prefix + '.pth'
        if os.path.isfile(pth_file):
            print(pth_file + ' found... skipping')
            continue

        ckpt, loss = generator.optimize(image, mask=None)
        latent = ckpt['current_z']
        loss_curve = np.array(loss).squeeze()
        reconstruction = renormalize.as_image(ckpt['current_x'][0])
        torch.save({'w': latent.detach().cpu()}, pth_file)
        np.savez(out_prefix + '_loss.npz', loss=loss_curve)
        reconstruction.save(out_prefix + '_optimized_im.png')
def optimize(opt):
    """Fit class-conditional StyleGAN2 latents for a slice of CIFAR-10.

    The class label given to the generator's ``optimize`` is the label
    *predicted* by the domain image classifier, not the dataset label. The
    resulting ``current_z`` is saved with ``np.save`` to
    ``<opt.w_path>/<partition>_<index:06d>.npy``. Existing outputs are
    skipped.

    Args:
        opt: options object. Reads ``opt.partition``, ``opt.indices``
            (``None`` for the whole split, otherwise ``(start, end)``) and
            ``opt.w_path``.
    """
    dataset_name = 'cifar10'
    generator_name = 'stylegan2-cc'  # class conditional stylegan
    transform = data.get_transform(dataset_name, 'imval')
    dset = data.get_dataset(dataset_name, opt.partition, load_w=False,
                            transform=transform)
    total = len(dset)
    start_idx, end_idx = ((0, total) if opt.indices is None
                          else (opt.indices[0], opt.indices[1]))

    generator = domain_generator.define_generator(
        generator_name, dataset_name, load_encoder=False)
    util.set_requires_grad(False, generator.generator)
    resnet = domain_classifier.define_classifier(dataset_name,
                                                 'imageclassifier')

    for idx in range(start_idx, end_idx):
        image, true_label = dset[idx]
        print("Running img %d/%d" % (idx, total))
        out_path = os.path.join(opt.w_path,
                                '%s_%06d.npy' % (opt.partition, idx))
        if os.path.isfile(out_path):
            print(out_path + ' found... skipping')
            continue

        image = image[None].cuda()
        # Classifier prediction only; no gradients needed.
        with torch.no_grad():
            logits = resnet(image)
        pred_label = logits.max(1)[1].item()
        print("True label %d prd label %d" % (true_label, pred_label))

        ckpt, loss = generator.optimize(image, pred_label)
        np.save(out_path, ckpt['current_z'].detach().cpu().numpy())
def optimize(opt):
    """Fit a StyleGAN2 latent for every car image in a slice of the dataset.

    Preprocessing for each image:

    1. Resize the car photo to 512 px width.
    2. Translate it so the centre of its bounding-box annotation lands in the
       middle of the image.
    3. Center-crop it with ``centercrop_tensor(..., 384, 512)``.
    4. Paste it into the middle of a zero-filled 3x512x512 canvas.

    The generator is then optimized against that canvas. The mask is zero
    outside the pasted region and equals the mask returned by
    ``shift_tensor`` inside it.

    Per image, ``<name>_loss.npz``, ``<name>_optimized_im.png`` and finally
    ``<name>.pth`` (``{'w': latent}``) are written to ``opt.w_path``. The
    ``.pth`` is the skip marker, so images whose ``.pth`` already exists are
    skipped.

    Args:
        opt: options object. Reads ``opt.partition``, ``opt.indices``
            (``None`` for the whole split, otherwise ``(start, end)``; ``end``
            is exclusive and clamped to the dataset length) and
            ``opt.w_path``.
    """
    dataset_name = 'car'
    generator_name = 'stylegan2'
    transform = data.get_transform(dataset_name, 'im2tensor')
    # loads the PIL image; the tensor transform is applied after resizing
    dset = data.get_dataset(dataset_name, opt.partition, load_w=False,
                            transform=None)
    total = len(dset)
    if opt.indices is None:
        start_idx = 0
        end_idx = total
    else:
        start_idx = opt.indices[0]
        # Clamp so an over-long range finishes cleanly instead of IndexError.
        end_idx = min(opt.indices[1], total)

    print("Optimizing dataset partition %s items %d to %d"
          % (opt.partition, start_idx, end_idx))

    generator = domain_generator.define_generator(generator_name, dataset_name,
                                                  load_encoder=True)
    # Freeze the generator and encoder parameters.
    util.set_requires_grad(False, generator.generator)
    util.set_requires_grad(False, generator.encoder)

    for i in range(start_idx, end_idx):
        (im, _label, bbox, path) = dset[i]
        img_filename = os.path.splitext(os.path.basename(path))[0]
        print("Running %d / %d images: %s" % (i, end_idx, img_filename))
        output_filename = os.path.join(opt.w_path, img_filename)
        if os.path.isfile(output_filename + '.pth'):
            print(output_filename + '.pth found... skipping')
            continue

        # scale image to 512 width
        width, height = im.size
        ratio = 512 / width
        new_width = 512
        new_height = int(ratio * height)
        # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same
        # filter and is available in every Pillow version.
        new_im = im.resize((new_width, new_height), Image.LANCZOS)
        print(im.size)
        print(new_im.size)
        # bbox is indexed as (x_min, y_min, x_max, y_max) below
        bbox = [int(x * ratio) for x in bbox]

        # shift to center the bbox annotation
        cx = (bbox[2] + bbox[0]) // 2
        cy = (bbox[3] + bbox[1]) // 2
        print("%d --> %d" % (cx, new_width // 2))
        print("%d --> %d" % (cy, new_height // 2))
        offset_x = new_width // 2 - cx
        offset_y = new_height // 2 - cy

        im_tensor = transform(new_im)
        im_tensor, mask = data.transforms.shift_tensor(im_tensor, offset_y,
                                                       offset_x)
        im_tensor = data.transforms.centercrop_tensor(im_tensor, 384, 512)
        mask = data.transforms.centercrop_tensor(mask, 384, 512)
        # now image size is at most 512 x 384 (could be smaller)

        # center car on 512x512 tensor
        crop_h, crop_w = im_tensor.shape[1], im_tensor.shape[2]
        disp_y = (512 - crop_h) // 2
        disp_x = (512 - crop_w) // 2
        rows = slice(disp_y, disp_y + crop_h)
        cols = slice(disp_x, disp_x + crop_w)
        centered_im = torch.zeros((3, 512, 512))
        centered_im[:, rows, cols] = im_tensor
        centered_mask = torch.zeros_like(centered_im)
        centered_mask[:, rows, cols] = mask

        # Batch dim added; the mask is reduced to a single channel.
        ckpt, loss = generator.optimize(centered_im[None].cuda(),
                                        centered_mask[:1][None].cuda())
        w_optimized = ckpt['current_z']
        loss = np.array(loss).squeeze()
        im_optimized = renormalize.as_image(ckpt['current_x'][0])
        np.savez(output_filename + '_loss.npz', loss=loss)
        im_optimized.save(output_filename + '_optimized_im.png')
        # Write the .pth (the skip marker checked above) last.
        torch.save({'w': w_optimized.detach().cpu()}, output_filename + '.pth')
def train(train_loader, model_dict, criterion_dict, optimizer_dict, fake_pool, recon_pool, WR_pool, visualizer, epoch, args, val_loader, fixed):
    """Run one training epoch of the password-conditioned face generator.

    Each iteration does the following:

    1. Align faces with ``alignment(landmarks)`` and an affine grid
       (112x96 output).
    2. Draw password codes with ``generate_code``.
    3. Produce these images with ``model_dict['G']``:

       - ``fake`` (M): from ``img`` with code ``z``.
       - ``rand_fake``: from ``img`` with ``rand_z``.
       - ``recon`` (R): from ``fake`` with the inverse code.
       - ``rand_recon`` and ``rand_recon_2nd`` (WR): from ``fake`` with two
         random inverse codes.

    4. Take one discriminator step (face-recognition and GAN losses).
    5. Take one generator step (GAN, optional infoGAN, face-recognition,
       feature and L1 losses).

    Logging, visdom display, checkpointing and validation then run at the
    configured frequencies.

    Args:
        train_loader: yields ``(img, label, landmarks, img_path)`` batches.
            Batches whose size differs from ``args.batch_size`` are skipped.
        model_dict: networks keyed ``'G'``, ``'D'``, ``'FR'``, ``'Q'``, plus
            the groups ``'G_nets'`` / ``'D_nets'`` passed to
            ``set_requires_grad``.
        criterion_dict: losses keyed ``'FR'``, ``'GAN'``, ``'DIS'``, ``'L1'``
            (plus whatever ``get_feat_loss`` reads).
        optimizer_dict: optimizers keyed ``'D'`` and ``'G'``.
        fake_pool, recon_pool, WR_pool: image pools queried with detached
            CPU tensors for the M / R / WR discriminator inputs.
        visualizer: logging/display helper.
        epoch: current epoch index (for logging and checkpoints).
        args: parsed options (loss weights, frequencies, device, ...).
        val_loader, fixed: forwarded to ``validate``.

    Side effect: on a display timeout this sets ``args.display_id = -1``,
    which disables visdom display for the rest of training.
    """
    iter_data_time = time.time()
    for i, (img, label, landmarks, img_path) in enumerate(train_loader):
        # The alignment grid and pool queries are sized with args.batch_size,
        # so a smaller (e.g. final) batch is dropped.
        if img.size(0) != args.batch_size:
            continue
        img_cuda = img.cuda(non_blocking=True)
        # iter_start_time / t_data are only set on print iterations and are
        # only read under the same condition in the logging part below.
        if i % args.print_loss_freq == 0:
            iter_start_time = time.time()
            t_data = iter_start_time - iter_data_time
        visualizer.reset()

        # -------------------- forward & get aligned --------------------
        theta = alignment(landmarks)
        grid = torch.nn.functional.affine_grid(
            theta, torch.Size((args.batch_size, 3, 112, 96)))

        # -------------------- generate password --------------------
        z, dis_target, rand_z, rand_dis_target, \
            inv_z, inv_dis_target, rand_inv_z, rand_inv_dis_target, \
            rand_inv_2nd_z, rand_inv_2nd_dis_target = generate_code(args.passwd_length, args.batch_size, args.device, inv=True, use_minus_one=args.use_minus_one, gen_random_WR=True)

        # Every image fed to FR gets its channel order reversed by [2, 1, 0]
        # (presumably RGB -> BGR for the FR network; confirm).
        real_aligned = grid_sample(img_cuda, grid)  # (B, 3, h, w)
        real_aligned = real_aligned[:, [2, 1, 0], ...]
        # NOTE(review): G is called here with CPU img/code but below with the
        # GPU `fake`; presumably G moves inputs to its device itself — confirm.
        fake = model_dict['G'](img, z.cpu())
        fake_aligned = grid_sample(fake, grid)
        fake_aligned = fake_aligned[:, [2, 1, 0], ...]
        recon = model_dict['G'](fake, inv_z)
        recon_aligned = grid_sample(recon, grid)
        recon_aligned = recon_aligned[:, [2, 1, 0], ...]
        rand_fake = model_dict['G'](img, rand_z.cpu())
        rand_fake_aligned = grid_sample(rand_fake, grid)
        rand_fake_aligned = rand_fake_aligned[:, [
            2, 1, 0, ], ...]
        rand_recon = model_dict['G'](fake, rand_inv_z)
        rand_recon_aligned = grid_sample(rand_recon, grid)
        rand_recon_aligned = rand_recon_aligned[:, [2, 1, 0], ...]
        rand_recon_2nd = model_dict['G'](fake, rand_inv_2nd_z)
        rand_recon_2nd_aligned = grid_sample(rand_recon_2nd, grid)
        rand_recon_2nd_aligned = rand_recon_2nd_aligned[:, [2, 1, 0], ...]

        # init loss dict for plot & print
        current_losses = {}

        # -------------------- D PART --------------------
        # init: only discriminator-side networks receive gradients
        set_requires_grad(model_dict['G_nets'], False)
        set_requires_grad(model_dict['D_nets'], True)
        optimizer_dict['D'].zero_grad()
        loss_D = 0.

        # ========== Face Recognition (FR) losses (L_{adv}, L_{rec\_cls}) ==========
        # FAKE FRs (generator outputs are detached so only FR is updated)
        # M: FR is trained to classify the modified face as the true identity
        id_fake = model_dict['FR'](fake_aligned.detach())[0]
        loss_D_FR_fake = criterion_dict['FR'](id_fake, label.to(args.device))
        # R & WR: the R term is negated, so FR is pushed away from the true
        # identity on reconstructions (the G part uses the opposite signs)
        id_recon = model_dict['FR'](recon_aligned.detach())[0]
        loss_D_FR_recon = -criterion_dict['FR'](id_recon, label.to(
            args.device))
        id_rand_recon = model_dict['FR'](rand_recon_aligned.detach())[0]
        loss_D_FR_rand_recon = criterion_dict['FR'](id_rand_recon, label.to(args.device))
        # Weighted average of M, R (weight 1) and WR terms
        loss_D_FR_fake_total = args.lambda_FR_M * loss_D_FR_fake + loss_D_FR_recon \
            + args.lambda_FR_WR * loss_D_FR_rand_recon
        loss_D_FR_fake_avg = loss_D_FR_fake_total / float(1. + args.lambda_FR_M + args.lambda_FR_WR)
        current_losses.update({
            'D_FR_M': loss_D_FR_fake.item(),
            'D_FR_R': loss_D_FR_recon.item(),
            'D_FR_WR': loss_D_FR_rand_recon.item(),
        })
        # REAL FR
        id_real = model_dict['FR'](real_aligned)[0]
        loss_D_FR_real = criterion_dict['FR'](id_real, label.to(args.device))
        loss_D += args.lambda_FR * (loss_D_FR_real + loss_D_FR_fake_avg) * 0.5
        current_losses.update({
            'D_FR_real': loss_D_FR_real.item(),
            'D_FR_fake': loss_D_FR_fake_avg.item()
        })

        # ========== GAN loss (L_{GAN}) ==========
        # fake: detached CPU tensors go through the image pools; D takes a
        # second argument naming the branch ('M', 'R', 'WR')
        all_M = torch.cat((
            fake.detach().cpu(),
            rand_fake.detach().cpu(),
        ), 0)
        pred_D_M = model_dict['D'](fake_pool.query(all_M, batch_size=args.batch_size), 'M')
        loss_D_M = criterion_dict['GAN'](pred_D_M, False)
        # R
        pred_D_recon = model_dict['D'](recon_pool.query(
            recon.detach().cpu(), batch_size=args.batch_size), 'R')
        loss_D_recon = criterion_dict['GAN'](pred_D_recon, False)
        # WR
        all_WR = torch.cat(
            (rand_recon.detach().cpu(), rand_recon_2nd.detach().cpu()), 0)
        pred_D_WR = model_dict['D'](WR_pool.query(all_WR, batch_size=args.batch_size), 'WR')
        loss_D_WR = criterion_dict['GAN'](pred_D_WR, False)
        loss_D_fake_total = args.lambda_GAN_M * loss_D_M + \
            args.lambda_GAN_recon * loss_D_recon + \
            args.lambda_GAN_WR * loss_D_WR
        loss_D_fake_total_weights = args.lambda_GAN_M + \
            args.lambda_GAN_recon + \
            args.lambda_GAN_WR
        loss_D_GAN_fake = loss_D_fake_total / loss_D_fake_total_weights
        current_losses.update({
            'D_GAN_M': loss_D_M.item(),
            'D_GAN_R': loss_D_recon.item(),
            'D_GAN_WR': loss_D_WR.item()
        })
        # real: the same real image is scored by each branch of D
        pred_D_real_M = model_dict['D'](img, 'M')
        pred_D_real_R = model_dict['D'](img, 'R')
        pred_D_real_WR = model_dict['D'](img, 'WR')
        loss_D_real_M = criterion_dict['GAN'](pred_D_real_M, True)
        loss_D_real_R = criterion_dict['GAN'](pred_D_real_R, True)
        loss_D_real_WR = criterion_dict['GAN'](pred_D_real_WR, True)
        loss_D_GAN_real = (args.lambda_GAN_M * loss_D_real_M +
                           args.lambda_GAN_recon * loss_D_real_R +
                           args.lambda_GAN_WR * loss_D_real_WR) / \
                          (args.lambda_GAN_M + args.lambda_GAN_recon + args.lambda_GAN_WR)
        loss_D += args.lambda_GAN * (loss_D_GAN_fake + loss_D_GAN_real) * 0.5
        current_losses.update({
            'D_GAN_real': loss_D_GAN_real.item(),
            'D_GAN_fake': loss_D_GAN_fake.item()
        })
        current_losses['D'] = loss_D.item()

        # D backward and optimizer steps
        loss_D.backward()
        optimizer_dict['D'].step()

        # -------------------- G PART --------------------
        # init: now only generator-side networks receive gradients
        set_requires_grad(model_dict['D_nets'], False)
        set_requires_grad(model_dict['G_nets'], True)
        optimizer_dict['G'].zero_grad()
        loss_G = 0

        # ========== GAN loss (L_{GAN}) ==========
        # Non-detached outputs so gradients reach G; only rand_recon (not
        # rand_recon_2nd) is used for the WR term here.
        pred_G_fake = model_dict['D'](fake, 'M')
        loss_G_GAN_fake = criterion_dict['GAN'](pred_G_fake, True)
        pred_G_recon = model_dict['D'](recon, 'R')
        loss_G_GAN_recon = criterion_dict['GAN'](pred_G_recon, True)
        pred_G_WR = model_dict['D'](rand_recon, 'WR')
        loss_G_GAN_WR = criterion_dict['GAN'](pred_G_WR, True)
        loss_G_GAN_total = args.lambda_GAN_M * loss_G_GAN_fake + \
            args.lambda_GAN_recon * loss_G_GAN_recon + \
            args.lambda_GAN_WR * loss_G_GAN_WR
        loss_G_GAN_total_weights = args.lambda_GAN_M + args.lambda_GAN_recon + args.lambda_GAN_WR
        loss_G_GAN = loss_G_GAN_total / loss_G_GAN_total_weights
        loss_G += args.lambda_GAN * loss_G_GAN
        current_losses.update({
            'G_GAN_M': loss_G_GAN_fake.item(),
            'G_GAN_R': loss_G_GAN_recon.item(),
            'G_GAN_WR': loss_G_GAN_WR.item(),
            'G_GAN': loss_G_GAN.item()
        })

        # ========== infoGAN loss (L_{aux}) ==========
        # Q predicts the password from an (input, output) image pair; the
        # password is scored as passwd_length // 4 separate classification
        # heads, each compared against the matching column of *_dis_target.
        if args.lambda_dis > 0:
            fake_dis_logits = model_dict['Q'](infoGAN_input(img_cuda, fake))
            infogan_fake_acc = 0
            loss_G_fake_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = fake_dis_logits[dis_idx].max(dim=1)[1]
                b = dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_fake_acc += acc.item()
                loss_G_fake_dis += criterion_dict['DIS'](
                    fake_dis_logits[dis_idx], dis_target[:, dis_idx])
            infogan_fake_acc = infogan_fake_acc / float(
                args.passwd_length // 4)

            recon_dis_logits = model_dict['Q'](infoGAN_input(fake, recon))
            infogan_recon_acc = 0
            loss_G_recon_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = recon_dis_logits[dis_idx].max(dim=1)[1]
                b = inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_recon_acc += acc.item()
                loss_G_recon_dis += criterion_dict['DIS'](
                    recon_dis_logits[dis_idx], inv_dis_target[:, dis_idx])
            infogan_recon_acc = infogan_recon_acc / float(
                args.passwd_length // 4)

            rand_recon_dis_logits = model_dict['Q'](infoGAN_input(
                fake, rand_recon))
            infogan_rand_recon_acc = 0
            loss_G_recon_rand_dis = 0
            for dis_idx in range(args.passwd_length // 4):
                a = rand_recon_dis_logits[dis_idx].max(dim=1)[1]
                b = rand_inv_dis_target[:, dis_idx]
                acc = torch.eq(a, b).type(torch.float).mean()
                infogan_rand_recon_acc += acc.item()
                loss_G_recon_rand_dis += criterion_dict['DIS'](
                    rand_recon_dis_logits[dis_idx],
                    rand_inv_dis_target[:, dis_idx])
            infogan_rand_recon_acc = infogan_rand_recon_acc / float(
                args.passwd_length // 4)

            dis_loss_total = loss_G_fake_dis + loss_G_recon_dis + loss_G_recon_rand_dis
            dis_acc_total = infogan_fake_acc + infogan_recon_acc + infogan_rand_recon_acc
            dis_cnt = 3
            loss_G += args.lambda_dis * dis_loss_total
            current_losses.update({
                'dis': dis_loss_total.item(),
                'dis_acc': dis_acc_total / float(dis_cnt)
            })

        # ========== Face Recognition (FR) loss (L_{adv}, L_{rec\_cls}) ==========
        # (netFR must not be fixed)
        # Signs are the opposite of the D part: G wants M and WR NOT to be
        # recognized as the true identity, and R to be recognized.
        id_fake_G, fake_feat = model_dict['FR'](fake_aligned)
        loss_G_FR_fake = -criterion_dict['FR'](id_fake_G, label.to(
            args.device))
        id_recon_G, recon_feat = model_dict['FR'](recon_aligned)
        loss_G_FR_recon = criterion_dict['FR'](id_recon_G, label.to(args.device))
        id_rand_recon_G, rand_recon_feat = model_dict['FR'](rand_recon_aligned)
        loss_G_FR_rand_recon = -criterion_dict['FR'](id_rand_recon_G, label.to(args.device))
        loss_G_FR_avg = (args.lambda_FR_M * loss_G_FR_fake + loss_G_FR_recon + args.lambda_FR_WR * loss_G_FR_rand_recon) /\
                        (args.lambda_FR_M + 1. + args.lambda_FR_WR)
        loss_G += args.lambda_FR * loss_G_FR_avg
        current_losses.update({
            'G_FR_M': loss_G_FR_fake.item(),
            'G_FR_R': loss_G_FR_recon.item(),
            'G_FR_WR': loss_G_FR_rand_recon.item(),
            'G_FR': loss_G_FR_avg.item()
        })

        # ========== Feature losses (L_{feat} is the sum of the three L_{dis}'s) ==========
        # Each term compares two FR feature sets via get_feat_loss:
        # M vs rand M, WR vs WR-2nd, and M vs WR.
        if args.feature_loss == 'cos':
            # make cos sim target (-1 for every pair in the batch)
            FR_cos_sim_target = torch.empty(size=(args.batch_size, 1), dtype=torch.float32, device=args.device)
            FR_cos_sim_target.fill_(-1.)
        else:
            FR_cos_sim_target = None

        # NOTE(review): these FR forwards run even when the feature-loss
        # weights below are zero.
        id_rand_fake_G, rand_fake_feat = model_dict['FR'](rand_fake_aligned)
        id_rand_recon_2nd_G, rand_recon_2nd_feat = model_dict['FR'](
            rand_recon_2nd_aligned)

        if args.lambda_Feat:
            loss_G_feat = get_feat_loss(fake_feat, rand_fake_feat, FR_cos_sim_target, args.feature_loss, criterion_dict)
            current_losses['G_feat'] = loss_G_feat.item()
        else:
            loss_G_feat = 0.

        if args.lambda_WR_Feat:
            loss_G_WR_feat = get_feat_loss(rand_recon_feat, rand_recon_2nd_feat, FR_cos_sim_target, args.feature_loss, criterion_dict)
            current_losses['G_WR_feat'] = loss_G_WR_feat.item()
        else:
            loss_G_WR_feat = 0.

        if args.lambda_false_recon_diff:
            loss_G_M_WR_feat = get_feat_loss(fake_feat, rand_recon_feat, FR_cos_sim_target, args.feature_loss, criterion_dict)
            current_losses['G_feat_M_WR'] = loss_G_M_WR_feat.item()
        else:
            loss_G_M_WR_feat = 0.

        loss_G += args.lambda_Feat * loss_G_feat + \
            args.lambda_WR_Feat * loss_G_WR_feat + \
            args.lambda_false_recon_diff * loss_G_M_WR_feat

        # ========== L1/Recon losses (L_1, L_{rec}) ==========
        # All three outputs are compared to the original input image.
        loss_G_L1 = criterion_dict['L1'](fake, img_cuda)
        loss_G_rand_recon_L1 = criterion_dict['L1'](rand_recon, img_cuda)
        loss_G_recon = criterion_dict['L1'](recon, img_cuda)
        loss_G += args.lambda_L1 * loss_G_L1 + \
            args.lambda_rand_recon_L1 * loss_G_rand_recon_L1 + \
            args.lambda_G_recon * loss_G_recon
        current_losses.update({
            'L1_M': loss_G_L1.item(),
            'recon': loss_G_recon.item(),
            'L1_WR': loss_G_rand_recon_L1.item()
        })

        current_losses['G'] = loss_G.item()

        # G backward and optimizer steps
        loss_G.backward()
        optimizer_dict['G'].step()

        # -------------------- LOGGING PART --------------------
        if i % args.print_loss_freq == 0:
            # per-sample compute time for this iteration
            t = (time.time() - iter_start_time) / args.batch_size
            visualizer.print_current_losses(epoch, i, current_losses, t, t_data)
            if args.display_id > 0 and i % args.plot_loss_freq == 0:
                visualizer.plot_current_losses(epoch, float(i) / len(train_loader), args, current_losses)

        if i % args.visdom_visual_freq == 0:
            save_result = i % args.update_html_freq == 0
            current_visuals = OrderedDict()
            current_visuals['real'] = img.detach()
            current_visuals['fake'] = fake.detach()
            current_visuals['rand_fake'] = rand_fake.detach()
            current_visuals['recon'] = recon.detach()
            current_visuals['rand_recon'] = rand_recon.detach()
            current_visuals['rand_recon_2nd'] = rand_recon_2nd.detach()
            # Guard the display call with a 60 s limit so a hung visdom
            # server cannot stall training.
            try:
                with time_limit(60):
                    visualizer.display_current_results(current_visuals, epoch, save_result, args)
            except TimeoutException:
                visualizer.logger.log(
                    'TIME OUT visualizer.display_current_results epoch:{} iter:{}. Change display_id to -1'
                    .format(epoch, i))
                # disable visdom display ever since
                args.display_id = -1

        # +1 so that we do not save/test for 0th iteration
        if (i + 1) % args.save_iter_freq == 0:
            save_model(epoch, model_dict, optimizer_dict, args, iter=i, save_sep=False)
            if args.display_id > 0:
                visualizer.vis.save([args.name])

        if (i + 1) % args.html_iter_freq == 0:
            validate(val_loader, model_dict, visualizer, epoch, args, fixed, i)

        # Restart the data-loading timer just before the next print iteration.
        if (i + 1) % args.print_loss_freq == 0:
            iter_data_time = time.time()
def train(train_loader, model_dict, criterion_dict, optimizer_dict, fake_pool, recon_pool, fake_pair_pool, WR_pool, visualizer, epoch, args, test_loader, fixed_z, fixed_rand_z): iter_data_time = time.time() for i, (img, label, landmarks, img_path) in enumerate(train_loader): iter_start_time = time.time() if i % args.print_loss_freq == 0: t_data = iter_start_time - iter_data_time visualizer.reset() batch_size = img.size(0) if args.lambda_dis > 0: # -------------------- generate password -------------------- z, dis_target, rand_z, rand_dis_target, inv_z, inv_dis_target, another_rand_z, another_rand_dis_target = generate_code(args.passwd_length, batch_size, args.device, inv=True) # -------------------- forward -------------------- # TODO: whether to detach fake = model_dict['G'](img, z.cpu()) rand_fake = model_dict['G'](img, rand_z.cpu()) if args.lambda_G_recon > 0: recon = model_dict['G'](fake, inv_z) rand_recon = model_dict['G'](fake, another_rand_z) else: fake = model_dict['G'](img) if args.lambda_G_recon > 0: recon = model_dict['G'](fake) # FR forward and FR losses theta = alignment(landmarks) grid = torch.nn.functional.affine_grid(theta, torch.Size((batch_size, 3, 112, 96))) real_aligned = torch.nn.functional.grid_sample(img.cuda(), grid) real_aligned = real_aligned[:, [2, 1, 0], ...] fake_aligned = torch.nn.functional.grid_sample(fake, grid) fake_aligned = fake_aligned[:, [2, 1, 0], ...] rand_fake_aligned = torch.nn.functional.grid_sample(rand_fake, grid) rand_fake_aligned = rand_fake_aligned[:, [2, 1, 0, ], ...] # (B, 3, h, w) if args.lambda_G_recon > 0: recon_aligned = torch.nn.functional.grid_sample(recon, grid) recon_aligned = recon_aligned[:, [2, 1, 0], ...] rand_recon_aligned = torch.nn.functional.grid_sample(rand_recon, grid) rand_recon_aligned = rand_recon_aligned[:, [2, 1, 0], ...] 
current_losses = {} # -------------------- D PART -------------------- if optimizer_dict['D'] is not None: set_requires_grad(model_dict['G_nets'], False) set_requires_grad(model_dict['D_nets'], True) optimizer_dict['D'].zero_grad() id_real = model_dict['FR'](real_aligned)[0] loss_D_FR_real = criterion_dict['FR'](id_real, label.to(args.device)) cnt_FR_fake = 0. loss_D_FR_fake_total = 0 if args.train_M: id_fake = model_dict['FR'](fake_aligned.detach())[0] id_rand_fake = model_dict['FR'](rand_fake_aligned.detach())[0] loss_D_FR_fake = criterion_dict['FR'](id_fake, label.to(args.device)) loss_D_FR_rand_fake = criterion_dict['FR'](id_rand_fake, label.to(args.device)) loss_D_FR_fake_total += loss_D_FR_fake + loss_D_FR_rand_fake cnt_FR_fake += 2. current_losses.update({'D_FR_fake': loss_D_FR_fake.item(), 'D_FR_rand': loss_D_FR_rand_fake.item(), # 'D_FR_rand_recon': loss_D_FR_rand_recon.item() }) if args.recon_FR: # TODO: rand_fake_recon FR loss? id_recon = model_dict['FR'](recon_aligned.detach())[0] loss_D_FR_recon = -criterion_dict['FR'](id_recon, label.to(args.device)) if args.lambda_FR_WR: id_rand_recon = model_dict['FR'](rand_recon_aligned.detach())[0] loss_D_FR_rand_recon = criterion_dict['FR'](id_rand_recon, label.to(args.device)) current_losses.update({'D_FR_rand_recon': loss_D_FR_rand_recon.item() }) else: loss_D_FR_rand_recon = 0. loss_D_FR_fake_total += loss_D_FR_recon + args.lambda_FR_WR * loss_D_FR_rand_recon cnt_FR_fake += 1. 
+ args.lambda_FR_WR current_losses.update({'D_FR_recon': loss_D_FR_recon.item(), # 'D_FR_rand_recon': loss_D_FR_rand_recon.item() }) loss_D_FR_fake_avg = loss_D_FR_fake_total / float(cnt_FR_fake) loss_D = args.lambda_FR * (loss_D_FR_real + loss_D_FR_fake_avg) * 0.5 current_losses.update({'D_FR_real': loss_D_FR_real.item(), 'D_FR_fake': loss_D_FR_fake_avg.item() # 'D_FR_fake': loss_D_FR_fake.item(), # 'D_FR_rand': loss_D_FR_rand_fake.item(), # 'D_FR_rand_recon': loss_D_FR_rand_recon.item() }) # GAN loss if args.lambda_GAN > 0: # real if args.recon_pair_GAN: assert args.single_GAN_recon_only real_input = torch.cat((img.cuda(), recon.detach()), dim=1) else: real_input = img pred_D_real = model_dict['D'](real_input) loss_D_real = criterion_dict['GAN'](pred_D_real, True) # fake loss_D_fake_total = 0. loss_D_fake_total_weights = 0. # recon if args.lambda_GAN_recon: if args.recon_pair_GAN: recon_input_to_pool = torch.cat((recon.detach().cpu(), img), dim=1) else: recon_input_to_pool = recon.detach().cpu() pred_D_recon = model_dict['D'](recon_pool.query(recon_input_to_pool)) loss_D_recon = criterion_dict['GAN'](pred_D_recon, False) loss_D_fake_total += args.lambda_GAN_recon * loss_D_recon loss_D_fake_total_weights += args.lambda_GAN_recon current_losses['D_recon'] = loss_D_recon.item() if not args.single_GAN_recon_only: assert args.lambda_pair_GAN == 0 if args.train_M: all_M = torch.cat((fake.detach().cpu(), rand_fake.detach().cpu(), ), 0) pred_D_M = model_dict['D'](fake_pool.query(all_M)) loss_D_M = criterion_dict['GAN'](pred_D_M, False) loss_D_fake_total += args.lambda_GAN_M * loss_D_M loss_D_fake_total_weights += args.lambda_GAN_M current_losses['D_M'] = loss_D_M.item() if args.lambda_GAN_WR: pred_D_WR = model_dict['D'](WR_pool.query(rand_recon.detach().cpu())) loss_D_WR = criterion_dict['GAN'](pred_D_WR, False) loss_D_fake_total += args.lambda_GAN_WR * loss_D_WR loss_D_fake_total_weights += args.lambda_GAN_WR current_losses['D_WR'] = loss_D_WR.item() loss_D_fake = 
loss_D_fake_total / loss_D_fake_total_weights loss_D += args.lambda_GAN * (loss_D_fake + loss_D_real) * 0.5 current_losses.update({ 'D_real': loss_D_real.item(), 'D_fake': loss_D_fake.item() }) if args.lambda_pair_GAN > 0: loss_pair_fake_total = 0 loss_pair_real_total = 0 loss_pair_cnt = 0. if args.train_M: pred_pair_real1 = model_dict['pair_D'](torch.cat((img.cuda(), fake.detach()), 1)) pred_pair_real2 = model_dict['pair_D'](torch.cat((img.cuda(), rand_fake.detach()), 1)) all_fake_pair = torch.cat((torch.cat((fake.detach().cpu(), img), 1), torch.cat((rand_fake.detach().cpu(), img), 1), ), 0) pred_pair_fake = model_dict['pair_D'](fake_pair_pool.query(all_fake_pair)) loss_pair_M_real = (criterion_dict['GAN'](pred_pair_real1, True) + criterion_dict['GAN'](pred_pair_real2, True)) / 2. loss_pair_M_fake = criterion_dict['GAN'](pred_pair_fake, False) loss_pair_real_total += loss_pair_M_real loss_pair_fake_total += loss_pair_M_fake loss_pair_cnt += 1 pred_pair_WR_real = model_dict['pair_D'](torch.cat((img.cuda(), rand_recon.detach()), 1)) pred_pair_WR_fake = model_dict['pair_D'](WR_pool.query(torch.cat((rand_recon.detach().cpu(), img), 1))) loss_pair_WR_real = criterion_dict['GAN'](pred_pair_WR_real, True) loss_pair_WR_fake = criterion_dict['GAN'](pred_pair_WR_fake, False) loss_pair_real_total += args.multiple_pair_WR_GAN * loss_pair_WR_real loss_pair_fake_total += args.multiple_pair_WR_GAN * loss_pair_WR_fake loss_pair_cnt += args.multiple_pair_WR_GAN loss_pair_D_real = loss_pair_real_total / loss_pair_cnt # (loss_pair_M_real + args.multiple_pair_WR_GAN * loss_pair_WR_real) / (1. + args.multiple_pair_WR_GAN) loss_pair_D_fake = loss_pair_fake_total / loss_pair_cnt #(loss_pair_M_fake + args.multiple_pair_WR_GAN * loss_pair_WR_fake) / (1. 
+ args.multiple_pair_WR_GAN) current_losses.update({ 'pair_D_fake': loss_pair_D_fake.item(), 'pair_D_real': loss_pair_D_real.item() }) loss_D += args.lambda_pair_GAN * (loss_pair_D_fake + loss_pair_D_real) * 0.5 current_losses['D'] = loss_D.item() # D backward and optimizer steps loss_D.backward() if args.gan_mode == 'wgangp': real_to_wgangp = torch.cat((img, img), 0).to(args.device) if np.random.rand() > 0.5: fake_selected = fake.detach() else: fake_selected = rand_fake.detach() fake_to_wgangp = torch.cat((fake_selected, rand_recon.detach()), 0) loss_gp, gradients = models.cal_gradient_penalty(model_dict['D'], real_to_wgangp, fake_to_wgangp, args.device) # print('gradeints abs/l2 mean:', gradients[0], gradients[1]) loss_gp *= args.lambda_GAN # print('loss_gp', loss_gp.item()) loss_gp.backward() optimizer_dict['D'].step() # -------------------- G PART -------------------- # init set_requires_grad(model_dict['D_nets'], False) set_requires_grad(model_dict['G_nets'], True) optimizer_dict['G'].zero_grad() loss_G = 0 # GAN loss if args.lambda_GAN > 0: loss_G_GAN_total = 0. loss_G_GAN_total_weights = 0. 
# recon if args.lambda_GAN_recon: if args.recon_pair_GAN: recon_input_G = torch.cat((recon, img.cuda()), dim=1) else: recon_input_G = recon pred_G_recon = model_dict['D'](recon_input_G) loss_G_recon = criterion_dict['GAN'](pred_G_recon, True) loss_G_GAN_total += args.lambda_GAN_recon * loss_G_recon loss_G_GAN_total_weights += args.lambda_GAN_recon current_losses['G_recon'] = loss_G_recon.item() if not args.single_GAN_recon_only: if args.train_M: pred_G_fake = model_dict['D'](fake) pred_G_rand_fake = model_dict['D'](rand_fake) loss_G_fake = criterion_dict['GAN'](pred_G_fake, True) loss_G_rand_fake = criterion_dict['GAN'](pred_G_rand_fake, True) loss_G_GAN_total += args.lambda_GAN_M * 0.5 * (loss_G_fake + loss_G_rand_fake) loss_G_GAN_total_weights += args.lambda_GAN_M current_losses['G_M'] = 0.5 * (loss_G_fake.item() + loss_G_rand_fake.item()) pred_G_WR = model_dict['D'](rand_recon) loss_G_WR = criterion_dict['GAN'](pred_G_WR, True) current_losses['G_WR'] = loss_G_WR.item() loss_G_GAN_total += args.lambda_GAN_WR * loss_G_WR loss_G_GAN_total_weights += args.lambda_GAN_WR loss_G_GAN = loss_G_GAN_total / loss_G_GAN_total_weights loss_G += args.lambda_GAN * loss_G_GAN current_losses.update({'G_GAN': loss_G_GAN.item(), }) if args.lambda_pair_GAN > 0: loss_pair_G_total = 0 cnt_pair_G = 0. if args.train_M: pred_pair_fake1_G = model_dict['pair_D'](torch.cat((fake, img.cuda()), 1)) pred_pair_fake2_G = model_dict['pair_D'](torch.cat((rand_fake, img.cuda()), 1)) loss_pair_M_G = (criterion_dict['GAN'](pred_pair_fake1_G, True) + criterion_dict['GAN'](pred_pair_fake2_G, True)) / 2. loss_pair_G_total += loss_pair_M_G cnt_pair_G += 1. 
pred_pair_fake3_G = model_dict['pair_D'](torch.cat((rand_recon, img.cuda()), 1)) loss_pair_WR_G = criterion_dict['GAN'](pred_pair_fake3_G, True) loss_pair_G_total += args.multiple_pair_WR_GAN * loss_pair_WR_G cnt_pair_G += args.multiple_pair_WR_GAN loss_pair_G_avg = loss_pair_G_total / cnt_pair_G loss_G += args.lambda_pair_GAN * loss_pair_G_avg current_losses['pair_G'] = loss_pair_G_avg.item() # infoGAN loss def infoGAN_input(img1, img2): if args.use_minus_Q: return img2 - img1 else: return torch.cat((img1, img2), 1) if args.lambda_dis > 0: infogan_acc = 0 infogan_inv_acc = 0 infogan_rand_acc = 0 infogan_recon_rand_acc = 0 dis_logits = model_dict['Q'](infoGAN_input(img.cuda(), fake)) loss_G_dis = 0 for dis_idx in range(args.passwd_length // 4): a = dis_logits[dis_idx].max(dim=1)[1] b = dis_target[:, dis_idx] acc = torch.eq(a, b).type(torch.float).mean() infogan_acc += acc.item() loss_G_dis += criterion_dict['DIS'](dis_logits[dis_idx], dis_target[:, dis_idx]) infogan_acc = infogan_acc / float(args.passwd_length // 4) inv_dis_logits = model_dict['Q'](infoGAN_input(fake, recon)) loss_G_inv_dis = 0 for dis_idx in range(args.passwd_length // 4): a = inv_dis_logits[dis_idx].max(dim=1)[1] b = inv_dis_target[:, dis_idx] acc = torch.eq(a, b).type(torch.float).mean() infogan_inv_acc += acc.item() loss_G_inv_dis += criterion_dict['DIS'](inv_dis_logits[dis_idx], inv_dis_target[:, dis_idx]) infogan_inv_acc = infogan_inv_acc / float(args.passwd_length // 4) rand_dis_logits = model_dict['Q'](infoGAN_input(img.cuda(), rand_fake)) loss_G_rand_dis = 0 for dis_idx in range(args.passwd_length // 4): a = rand_dis_logits[dis_idx].max(dim=1)[1] b = rand_dis_target[:, dis_idx] acc = torch.eq(a, b).type(torch.float).mean() infogan_rand_acc += acc.item() loss_G_rand_dis += criterion_dict['DIS'](rand_dis_logits[dis_idx], rand_dis_target[:, dis_idx]) infogan_rand_acc = infogan_rand_acc / float(args.passwd_length // 4) recon_rand_dis_logits = model_dict['Q'](infoGAN_input(fake, rand_recon)) 
loss_G_recon_rand_dis = 0 for dis_idx in range(args.passwd_length // 4): a = recon_rand_dis_logits[dis_idx].max(dim=1)[1] b = another_rand_dis_target[:, dis_idx] acc = torch.eq(a, b).type(torch.float).mean() infogan_recon_rand_acc += acc.item() loss_G_recon_rand_dis += criterion_dict['DIS'](recon_rand_dis_logits[dis_idx], another_rand_dis_target[:, dis_idx]) infogan_recon_rand_acc = infogan_recon_rand_acc / float(args.passwd_length // 4) # current_losses.update({'G_dis': loss_G_dis.item(), # 'G_inv_dis': loss_G_inv_dis.item(), # 'G_dis_acc': infogan_acc, # 'G_inv_dis_acc': infogan_inv_acc, # 'G_rand_dis': loss_G_rand_dis.item(), # 'G_recon_rand_dis': loss_G_recon_rand_dis.item(), # 'G_rand_dis_acc': infogan_rand_acc, # 'G_recon_rand_dis_acc': infogan_recon_rand_acc # }) loss_dis = (loss_G_dis + loss_G_inv_dis + loss_G_rand_dis + loss_G_recon_rand_dis) dis_acc = (infogan_acc + infogan_inv_acc + infogan_rand_acc + infogan_recon_rand_acc) / 4. loss_G += args.lambda_dis * loss_dis current_losses.update({ 'dis': loss_dis.item(), 'dis_acc': dis_acc }) # FR loss, netFR must not be fixed loss_G_FR_total = 0 cnt_G_FR = 0. if args.train_M: id_fake_G, fake_feat = model_dict['FR'](fake_aligned) loss_G_FR = -criterion_dict['FR'](id_fake_G, label.to(args.device)) # current_losses['G_FR'] = loss_G_FR.item() id_rand_fake_G, rand_fake_feat = model_dict['FR'](rand_fake_aligned) loss_G_FR_rand = -criterion_dict['FR'](id_rand_fake_G, label.to(args.device)) # current_losses['G_FR_rand'] = loss_G_FR_rand.item() loss_G_FR_total += loss_G_FR + loss_G_FR_rand cnt_G_FR += 2 if args.feature_loss == 'cos': FR_cos_sim_target = torch.empty(size=(batch_size, 1), dtype=torch.float32, device=args.device) FR_cos_sim_target.fill_(-1.) 
if args.lambda_Feat: if args.feature_loss == 'cos': loss_G_feat = criterion_dict['Feat'](fake_feat, rand_fake_feat, target=FR_cos_sim_target) else: loss_G_feat = -criterion_dict['Feat'](fake_feat, rand_fake_feat) current_losses['G_feat'] = loss_G_feat.item() loss_G += args.lambda_Feat * loss_G_feat if args.lambda_G_recon: id_recon_G, recon_feat = model_dict['FR'](recon_aligned) if args.lambda_FR_WR: id_rand_recon_G, rand_recon_feat = model_dict['FR'](rand_recon_aligned) if args.lambda_recon_Feat: if args.feature_loss == 'cos': loss_G_recon_feat = criterion_dict['Feat'](recon_feat, rand_recon_feat, target=FR_cos_sim_target) else: loss_G_recon_feat = -criterion_dict['Feat'](recon_feat, rand_recon_feat) current_losses['G_recon_feat'] = loss_G_recon_feat.item() loss_G += args.lambda_recon_Feat * loss_G_recon_feat if args.lambda_false_recon_diff: if args.feature_loss == 'cos': loss_G_false_recon_feat =criterion_dict['Feat'](fake_feat, rand_recon_feat, target=FR_cos_sim_target) else: loss_G_false_recon_feat =-criterion_dict['Feat'](fake_feat, rand_recon_feat) current_losses['G_false_recon_feat'] = loss_G_false_recon_feat.item() loss_G += args.lambda_false_recon_diff * loss_G_false_recon_feat if args.recon_FR: loss_G_FR_recon = criterion_dict['FR'](id_recon_G, label.to(args.device)) # current_losses['G_FR_recon'] = loss_G_FR_recon.item() if args.lambda_FR_WR: loss_G_FR_rand_recon = -criterion_dict['FR'](id_rand_recon_G, label.to(args.device)) else: loss_G_FR_rand_recon = 0. # current_losses['G_FR_rand_recon'] = loss_G_FR_rand_recon.item() loss_G_FR_total += loss_G_FR_recon + args.lambda_FR_WR * loss_G_FR_rand_recon cnt_G_FR += 1. 
+ args.lambda_FR_WR loss_G_FR_avg = loss_G_FR_total / cnt_G_FR loss_G += args.lambda_FR * loss_G_FR_avg current_losses['G_FR'] = loss_G_FR_avg.item() # loss_L1 = 0 # cnt_loss_L1 = 0 if args.lambda_L1 > 0: loss_G_L1 = criterion_dict['L1'](fake, img.cuda()) current_losses['L1'] = loss_G_L1.item() # loss_L1 += loss_G_L1.item() # cnt_loss_L1 += 1 loss_G += args.lambda_L1 * loss_G_L1 if args.lambda_rand_L1 > 0: loss_G_rand_L1 = criterion_dict['L1'](rand_fake, img.cuda()) current_losses['rand_L1'] = loss_G_rand_L1.item() # loss_L1 += loss_G_rand_L1.item() # cnt_loss_L1 += 1 loss_G += args.lambda_rand_L1 * loss_G_rand_L1 if args.lambda_rand_recon_L1 > 0: loss_G_rand_recon_L1 = criterion_dict['L1'](rand_recon, img.cuda()) current_losses['wrong_recon_L1'] = loss_G_rand_recon_L1.item() # loss_L1 += loss_G_rand_recon_L1.item() # cnt_loss_L1 += 1 loss_G += args.lambda_rand_recon_L1 * loss_G_rand_recon_L1 # current_losses['L1'] = loss_L1 / float(cnt_loss_L1) if args.lambda_G_recon > 0: loss_G_recon = criterion_dict['L1'](recon, img.cuda()) loss_G += args.lambda_G_recon * loss_G_recon current_losses['recon'] = loss_G_recon.item() if args.lambda_G_rand_recon > 0: if args.use_minus_one: inv_rand_z = rand_z * -1 else: inv_rand_z = 1.0 - rand_z rand_fake_recon = model_dict['G'](rand_fake, inv_rand_z) loss_G_rand_recon = criterion_dict['L1'](rand_fake_recon, img.cuda()) loss_G += args.lambda_G_rand_recon * loss_G_rand_recon current_losses['another_recon'] = loss_G_rand_recon.item() current_losses['G'] = loss_G.item() # G backward and optimizer steps loss_G.backward() optimizer_dict['G'].step() # -------------------- LOGGING PART -------------------- if i % args.print_loss_freq == 0: t = (time.time() - iter_start_time) / batch_size visualizer.print_current_losses(epoch, i, current_losses, t, t_data) if args.display_id > 0 and i % args.plot_loss_freq == 0: visualizer.plot_current_losses(epoch, float(i) / len(train_loader), args, current_losses) if args.print_gradient: for net_name, net 
in model_dict.items(): # if net_name != 'Q': # continue if isinstance(net, list): continue print(('================ NET %s ================' % net_name)) for name, param in net.named_parameters(): print_param_info(name, param, print_std=True) if i % args.visdom_visual_freq == 0: save_result = i % args.update_html_freq == 0 current_visuals = OrderedDict() current_visuals['real'] = img.detach() current_visuals['fake'] = fake.detach() current_visuals['rand_fake'] = rand_fake.detach() if args.lambda_G_recon: current_visuals['recon'] = recon.detach() current_visuals['rand_recon'] = rand_recon.detach() if args.lambda_G_rand_recon > 0: current_visuals['rand_fake_recon'] = rand_fake_recon.detach() current_visuals['real_aligned'] = real_aligned.detach() current_visuals['fake_aligned'] = fake_aligned.detach() current_visuals['rand_fake_aligned'] = rand_fake_aligned.detach() if args.lambda_G_recon: current_visuals['recon_aligned'] = recon_aligned.detach() current_visuals['rand_recon_aligned'] = rand_recon_aligned.detach() try: with time_limit(60): visualizer.display_current_results(current_visuals, epoch, save_result, args) except TimeoutException: visualizer.logger.log('TIME OUT visualizer.display_current_results epoch:{} iter:{}. Change display_id to -1'.format(epoch, i)) args.display_id = -1 if (i + 1) % args.save_iter_freq == 0: save_model(epoch, model_dict, optimizer_dict, args, iter=i) if args.display_id > 0: visualizer.vis.save([args.name]) visualizer.overview_vis.save(['overview']) if (i + 1) % args.html_iter_freq == 0: test(test_loader, model_dict, criterion_dict, visualizer, epoch, args, fixed_z, fixed_rand_z, i) if (i + 1) % args.print_loss_freq == 0: iter_data_time = time.time()