def KL_of_encoded_G_output(generator, encoder, z, batch_N, session, alt_mix_z, mix_N):
    KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
    assert alt_mix_z is None or alt_mix_z.size()[0] == z.size()[0]  # Batch sizes must match

    mix_style_layer_begin = np.random.randint(
        low=1, high=(session.phase + 1)) if alt_mix_z is not None else -1
    z_intermediates = utils.gen_seq([(z, 0, mix_style_layer_begin),
                                     (alt_mix_z, mix_style_layer_begin, -1)],
                                    generator, session,
                                    retain_intermediate_results=True)
    assert z_intermediates[0].size()[0] == z.size()[0]  # Batch sizes must remain invariant

    fake = z_intermediates[1 if alt_mix_z is not None else 0]
    z_intermediate = z_intermediates[0][:mix_N, :]

    egz = encoder(fake, session.getResoPhase(), session.alpha, args.use_ALQ)

    if mix_style_layer_begin > -1:
        egz_intermediate = utils.gen_seq(
            [(egz[:mix_N, :], 0, mix_style_layer_begin)],
            generator, session)  # Or we could just call the generator directly, of course.
        assert egz_intermediate.size()[0] == z_intermediate.size()[0]
    else:
        egz_intermediate = z_intermediate = None

    # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
    return egz, KL_minimizer(egz) * args.fake_G_KL_scale, egz_intermediate, z_intermediate
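# --- Sketch: the moment-matching divergence that KLN01Loss is assumed to compute. ---
# A minimal, self-contained illustration only; the actual KLN01Loss class (its
# `direction` handling and the `samples_mean` / `samples_var` bookkeeping used for
# stats below) lives elsewhere in this repo. The sketch fits a diagonal Gaussian
# N(mu, var) to a batch of codes and evaluates the closed-form divergence against
# the unit Gaussian:
#     KL(N(mu, var) || N(0, I)) = 0.5 * sum(var + mu^2 - 1 - log var)
def kl_n01_sketch(codes, minimize=True):
    mu = codes.mean(dim=0)    # per-dimension batch mean
    var = codes.var(dim=0)    # per-dimension batch variance
    kl = 0.5 * (var + mu * mu - 1.0 - torch.log(var + 1e-8)).sum()
    return kl if minimize else -kl  # the maximizing direction just flips the sign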
def reconstruct(input_image, encoder, generator, session, style_i=-1, style_layer_begin=0, style_layer_end=-1):
    with torch.no_grad():
        style_input = encoder(Variable(input_image), session.getResoPhase(),
                              session.alpha, args.use_ALQ).detach()

        # Broadcast the latent of image #style_i to every row. gen_seq then drives the
        # layers outside [style_layer_begin, style_layer_end) with this shared latent,
        # while the layers inside that range keep each image's own latent.
        z = style_input[style_i].repeat(style_input.size()[0], 1)

        # The gen_seq call below unwraps as follows:
        # z_w = generator(None, None, session.phase, session.alpha, style_input=z,
        #                 style_layer_begin=0, style_layer_end=style_layer_begin).detach()
        # z_w = generator(z_w, None, session.phase, session.alpha, style_input=style_input,
        #                 style_layer_begin=style_layer_begin, style_layer_end=style_layer_end)
        # gex = generator(z_w, None, session.phase, session.alpha, style_input=z,
        #                 style_layer_begin=style_layer_end, style_layer_end=-1)
        gex = utils.gen_seq([
            (z, 0, style_layer_begin),
            (style_input, style_layer_begin, style_layer_end),
            (z, style_layer_end, -1),
        ], generator, session).detach()

    return gex.data[:]
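# --- Hypothetical usage sketch for reconstruct(). The argument names below are
# placeholders, not symbols defined in this file. With these arguments, layers
# [2, 4) keep each image's own latent while every other layer is driven by the
# latent of image #0, i.e. the coarse and fine structure are borrowed from image #0:
#
#   recos = reconstruct(input_batch, encoder, generator, session,
#                       style_i=0, style_layer_begin=2, style_layer_end=4)
#   torchvision.utils.save_image(recos, 'mixed_reco.png', normalize=True, range=(-1, 1))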
def generate_intermediate_samples(generator, global_i, session, writer=None, collateImages=True):
    generator.eval()
    with torch.no_grad():
        save_root = args.sample_dir if args.sample_dir is not None else args.save_dir
        utils.requires_grad(generator, False)

        # The total number of samples is samplesRepeatN * colN * rowN,
        # e.g. for 51200 samples, the outcome is 5*80*128. Please only give multiples of 128 here.
        samplesRepeatN = max(1, int(args.sample_N / 128)) if not collateImages else 1
        reso = session.getReso()

        if not collateImages:
            special_dir = '{}/sample/{}'.format(save_root, str(global_i).zfill(6))
            while os.path.exists(special_dir):
                special_dir += '_'
            os.makedirs(special_dir)

        for outer_count in range(samplesRepeatN):
            colN = 1 if not collateImages else min(10, int(np.ceil(args.sample_N / 4.0)))
            rowN = 128 if not collateImages else min(5, int(np.ceil(args.sample_N / 4.0)))
            images = []
            myzs = []
            for _ in range(rowN):
                myz = {}
                for i in range(2):
                    myz0 = Variable(torch.randn(args.n_label * colN, args.nz + 1)).to(device=args.device)
                    myz[i] = utils.normalize(myz0)
                    #myz[i], input_class = utils.split_labels_out_of_latent(myz0)

                if args.randomMixN <= 1:
                    new_imgs = generator(
                        myz[0],
                        None,  # input_class,
                        session.phase,
                        session.alpha).detach()  #.data.cpu()
                else:
                    max_i = session.phase
                    style_layer_begin = 2  #np.random.randint(low=1, high=(max_i + 1))
                    print("Cut at layer {} for the next {} random samples.".format(
                        style_layer_begin, myz[0].shape[0]))
                    new_imgs = utils.gen_seq([
                        (myz[0], 0, style_layer_begin),
                        (myz[1], style_layer_begin, -1)
                    ], generator, session).detach()
                images.append(new_imgs)
                myzs.append(myz[0])

            # TODO: Support mixed latents with rejection sampling
            if args.rejection_sampling > 0 and not collateImages:
                filtered_images = []
                batch_size = 8
                for b in range(int(len(images) / batch_size)):
                    img_batch = images[b * batch_size:(b + 1) * batch_size]
                    reco_z = session.encoder(Variable(torch.cat(img_batch, 0)),
                                             session.getResoPhase(), session.alpha,
                                             args.use_ALQ).detach()
                    cost_z = utils.mismatchV(
                        torch.cat(myzs[b * batch_size:(b + 1) * batch_size], 0),
                        reco_z, 'cos')
                    _, ii = torch.sort(cost_z)
                    keeper = args.rejection_sampling / 100.0
                    keepN = max(1, int(len(ii) * keeper))
                    for iindex in ii[:keepN]:  # Retain the best ones only
                        filtered_images.append(img_batch[iindex])
                images = filtered_images

            if collateImages:
                sample_dir = '{}/sample'.format(save_root)
                if not os.path.exists(sample_dir):
                    os.makedirs(sample_dir)
                save_path = '{}/{}.png'.format(sample_dir, str(global_i + 1).zfill(6))
                torchvision.utils.save_image(
                    torch.cat(images, 0),
                    save_path,
                    nrow=args.n_label * colN,
                    normalize=True,
                    range=(-1, 1),
                    padding=0)
                # Hacky, but this is an easy way to rescale the images into one large grid:
                im = Image.open(save_path)
                im2 = im.resize((1024, 512 if reso < 256 else 1024))
                im2.save(save_path)
                #if writer:
                #    writer.add_image('samples_latest_{}'.format(session.phase), torch.cat(images, 0), session.phase)
            else:
                print("Generating samples at {}...".format(special_dir))
                for ii, img in enumerate(images):
                    ipath = '{}/{}_{}.png'.format(special_dir, str(global_i + 1).zfill(6),
                                                  ii + outer_count * 128)
                    torchvision.utils.save_image(
                        img, ipath,
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)
    generator.train()
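# --- Sketch of the per-sample cosine mismatch that utils.mismatchV(a, b, 'cos') is
# assumed to return in the rejection-sampling step above: one score per row, lower
# meaning a better encoder round-trip. The exact scaling used by mismatchV is an
# assumption; only the use of the scores for sorting matters here.
def cosine_mismatch_sketch(z_a, z_b):
    z_a = z_a.view(z_a.size(0), -1)
    z_b = z_b.view(z_b.size(0), -1)
    # 1 - cos(angle): 0 for identical directions, up to 2 for opposite ones
    return 1.0 - torch.nn.functional.cosine_similarity(z_a, z_b, dim=1)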
def encoder_train(session, real_image, generatedImagePool, batch_N, match_x, stats, kls, margin):
    encoder = session.encoder
    generator = session.generator
    encoder.zero_grad()
    generator.zero_grad()
    x = Variable(real_image).to(device=args.device)  #async=(args.gpu_count>1))

    KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)
    KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
    e_losses = []

    flipInvarianceLayer = args.flip_invariance_layer
    flipX = flipInvarianceLayer > -1 and session.phase >= flipInvarianceLayer
    phiAdaCotrain = args.phi_ada_cotrain
    if flipX:
        phiAdaCotrain = True

    if phiAdaCotrain:
        for b in generator.module.adanorm_blocks:
            if b is not None and b.mod is not None:
                for param in b.mod.parameters():
                    param.requires_grad_(True)

    if flipX:
        x_in = x[0:int(x.size()[0] / 2), :, :, :]
        x_mirror = x_in.clone().detach().requires_grad_(True).to(
            device=args.device).flip(dims=[3])
        x[int(x.size()[0] / 2):, :, :, :] = x_mirror

    real_z = encoder(x, session.getResoPhase(), session.alpha, args.use_ALQ)

    if args.use_real_x_KL:
        # KL_real: - \Delta( e(X) , Z ) -> max_e
        if not flipX:
            KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
        else:
            # Treat the codes of each flip direction of the data as separate distributions
            z_in = real_z[0:int(x.size()[0] / 2), :]
            z_in_mirror = real_z[int(x.size()[0] / 2):, :]
            KL_real = (KL_minimizer(z_in) + KL_minimizer(z_in_mirror)) * args.real_x_KL_scale / 2
        e_losses.append(KL_real)

        stats['real_mean'] = KL_minimizer.samples_mean.data.mean().item()
        stats['real_var'] = KL_minimizer.samples_var.data.mean().item()
        stats['KL_real'] = KL_real.data.item()
        kls = "{0:.3f}".format(stats['KL_real'])

    if flipX:
        x_reco_in = utils.gen_seq([
            (z_in, 0, flipInvarianceLayer),  # Rotation of x_in, the ID of x_in_mirror (= the ID of x_in)
            (z_in_mirror, flipInvarianceLayer, -1)
        ], session.generator, session)
        x_reco_in_mirror = utils.gen_seq([
            (z_in_mirror, 0, flipInvarianceLayer),  # Vice versa
            (z_in, flipInvarianceLayer, -1)
        ], session.generator, session)

        if args.match_x_metric == 'robust':
            loss_flip = torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun(
                (x_in - x_reco_in).view(-1, x_in.size()[1] * x_in.size()[2] * x_in.size()[3]))) * match_x * 0.2 + \
                torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun(
                    (x_mirror - x_reco_in_mirror).view(-1, x_mirror.size()[1] * x_mirror.size()[2] * x_mirror.size()[3]))) * match_x * 0.2
        else:
            loss_flip = (utils.mismatch(x_in, x_reco_in, args.match_x_metric) +
                         utils.mismatch(x_mirror, x_reco_in_mirror, args.match_x_metric)) * args.match_x

        loss_flip.backward(retain_graph=True)
        stats['loss_flip'] = loss_flip.data.item()
        stats['x_reconstruction_error'] = loss_flip.data.item()
        print('Flip loss: {}'.format(stats['loss_flip']))
    else:
        if args.use_loss_x_reco:
            recon_x = generator(real_z, None, session.phase, session.alpha)
            # match_x: E_x||g(e(x)) - x|| -> min_e
            if args.match_x_metric == 'robust':
                err_simple = utils.mismatch(recon_x, x, 'L1') * match_x
                err = torch.mean(session.adaptive_loss[session.getResoPhase()].lossfun(
                    (recon_x - x).view(-1, x.size()[1] * x.size()[2] * x.size()[3]))) * match_x * 0.2
                print("err vs. ROBUST err: {} / {}".format(err_simple, err))
            else:
                err_simple = utils.mismatch(recon_x, x, args.match_x_metric) * match_x
                err = err_simple
            if phiAdaCotrain:
                err.backward(retain_graph=True)
            else:
                e_losses.append(err)
            stats['x_reconstruction_error'] = err.data.item()

    if phiAdaCotrain:
        for b in session.generator.module.adanorm_blocks:
            if b is not None and b.mod is not None:
                for param in b.mod.parameters():
                    param.requires_grad_(False)

    if args.use_loss_fake_D_KL:
        # TODO: The following block is essentially the same as the KL_minimizer part on the G side. Unify.
        mix_ratio = args.stylemix_E  # e.g. 0.25
        mix_N = int(mix_ratio * batch_N)
        z = Variable(torch.FloatTensor(batch_N, args.nz, 1, 1)).to(
            device=args.device)  #async=(args.gpu_count>1))
        utils.populate_z(z, args.nz + args.n_label, args.noise, batch_N)

        if session.phase > 0 and mix_N > 0:
            alt_mix_z = Variable(torch.FloatTensor(mix_N, args.nz, 1, 1)).to(
                device=args.device)  #async=(args.gpu_count>1))
            utils.populate_z(alt_mix_z, args.nz + args.n_label, args.noise, mix_N)
            alt_mix_z = torch.cat(
                (alt_mix_z, z[mix_N:, :]), dim=0) if mix_N < z.size()[0] else alt_mix_z
        else:
            alt_mix_z = None

        with torch.no_grad():
            style_layer_begin = np.random.randint(
                low=1, high=(session.phase + 1)) if alt_mix_z is not None else -1
            fake = utils.gen_seq([(z, 0, style_layer_begin),
                                  (alt_mix_z, style_layer_begin, -1)],
                                 generator, session).detach()

        if session.alpha >= 1.0:
            fake = generatedImagePool.query(fake.data)

        fake.requires_grad_()

        # e(g(Z))
        egz = encoder(fake, session.getResoPhase(), session.alpha, args.use_ALQ)
        # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
        KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

        m = margin
        if m > 0.0:
            # KL_fake is always negative, with abs value typically larger than KL_real.
            # Hence, the sum is negative and must be gapped so that its minimum is the
            # negative of the margin.
            KL_loss = torch.min(
                torch.ones_like(KL_fake) * m / 2,
                torch.max(-torch.ones_like(KL_fake) * m, KL_real + KL_fake))
        else:
            KL_loss = KL_real + KL_fake
        e_losses.append(KL_loss)

        stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
        stats['fake_var'] = KL_maximizer.samples_var.data.mean()
        stats['KL_fake'] = -KL_fake.data.item()
        stats['KL_loss'] = KL_loss.data.item()

        kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

    # Update e
    if len(e_losses) > 0:
        e_loss = sum(e_losses)
        e_loss.backward()
        stats['E_loss'] = np.float32(e_loss.data.cpu())
        session.optimizerD.step()

    if flipX:
        # The AdaNorm params need to be updated separately since they live on the "generator side"
        session.optimizerA.step()

    return kls
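# --- Illustration of the margin gap applied in encoder_train() above (a standalone
# sketch, not called by the training loop). torch.clamp reproduces the nested
# min/max exactly: with m = 2.0 the combined KL_real + KL_fake is confined to
# [-2.0, 1.0], so the adversarial KL term cannot be pushed arbitrarily far negative.
def margin_gap_sketch(kl_real, kl_fake, m):
    return torch.clamp(kl_real + kl_fake, min=-m, max=m / 2)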