def reconstruction_dryrun(generator, encoder, loader, session):
    """Run a fixed number of encode/decode passes to warm the models up.

    Switches both networks to eval mode with gradients disabled, draws a
    single 4-image batch at the current resolution, then repeatedly encodes
    and decodes it, discarding every result. Train mode is restored before
    returning.
    """
    generator.eval()
    encoder.eval()
    utils.requires_grad(generator, False)
    utils.requires_grad(encoder, False)

    resolution = 4 * 2 ** session.phase
    warmup_rounds = 200
    print('Warm-up rounds: {}'.format(warmup_rounds))

    # Early phases use the plain sampler; later ones need session state.
    if session.phase < 1:
        batch_iter = data.Utils.sample_data(loader, 4, resolution)
    else:
        batch_iter = data.Utils.sample_data2(loader, 4, resolution, session)
    real_image, _ = next(batch_iter)
    x = utils.conditional_to_cuda(Variable(real_image))

    for _ in range(warmup_rounds):
        latent = encoder(x, session.phase, session.alpha, args.use_ALQ).detach()
        latent, label = utils.split_labels_out_of_latent(latent)
        generator(latent, label, session.phase, session.alpha).detach()

    encoder.train()
    generator.train()
def reconstruct(input_image, encoder, generator, session):
    """Encode an image batch and decode it back; return the reconstruction."""
    with torch.no_grad():
        latent = encoder(Variable(input_image), session.phase, session.alpha,
                         args.use_ALQ).detach()
        latent, label = utils.split_labels_out_of_latent(latent)
        recon = generator(latent, label, session.phase, session.alpha).detach()
    return recon.data[:]
def extract_latent_presentation_from_image(input_image, encoder, session):
    """Encode an image batch and return its latent representation.

    The label entries the encoder appends to the latent code are split off
    and discarded; only the pure latent part is returned.

    Fix: the original used ``Variable(input_image, volatile=True)``;
    ``volatile`` was removed in PyTorch 0.4 and no longer disables autograd.
    Inference is now expressed with ``torch.no_grad()``, consistent with
    ``reconstruct()`` in this file.
    """
    with torch.no_grad():
        ex = encoder(Variable(input_image), session.phase, session.alpha,
                     args.use_ALQ).detach()
        ex, label = utils.split_labels_out_of_latent(ex)
    return ex
def KL_of_encoded_G_output(generator, z):
    """Sample z, decode it, re-encode the fake, and return the G-side KL term.

    Returns (egz, label, scaled_KL, z) where egz = e(g(z)) and scaled_KL is
    KL_minimizer(egz) * args.fake_G_KL_scale (the term G minimizes).
    """
    # NOTE(review): this module-level function uses several free names
    # (reso, session, encoder, KL_minimizer, batch_size) that are not
    # parameters or module imports visible here. It duplicates the closure
    # of the same name defined inside train(); at module level it can only
    # work if those names happen to exist as globals — confirm it is
    # actually called from anywhere.
    # Re-fill z in place with fresh noise of dimension nz + n_label.
    utils.populate_z(z, args.nz + args.n_label, args.noise, batch_size(reso))
    z, label = utils.split_labels_out_of_latent(z)
    fake = generator(z, label, session.phase, session.alpha)
    egz = encoder(fake, session.phase, session.alpha, args.use_ALQ)
    # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
    return egz, label, KL_minimizer(egz) * args.fake_G_KL_scale, z
def reconstruct_from_latent_presentation(input, generator, session):
    """Decode a latent batch (with trailing label entries) into CPU images.

    The latent is moved to GPU, normalized, split into (latent, label), and
    decoded by the generator at the session's current phase/alpha.

    Note: the parameter name ``input`` shadows the builtin; it is kept
    unchanged for backward compatibility with keyword callers.
    """
    latent = Variable(input).cuda()
    latent = utils.normalize(latent)
    latent, input_class = utils.split_labels_out_of_latent(latent)
    decoded = generator(latent, input_class, session.phase,
                        session.alpha).detach().data.cpu()
    return decoded
def D_prediction_of_G_output(generator, encoder, step, alpha):
    """Generate a fake batch from fresh noise and return D's mean score on it.

    Returns (loss, fake_image) where loss is the mean encoder/discriminator
    prediction for the generated batch.

    Fix: ``.cuda(async=...)`` is a SyntaxError on Python 3.7+ (``async``
    became a reserved keyword); replaced with the PyTorch 0.4+
    ``non_blocking=`` spelling already used elsewhere in this file.
    """
    # To use labels, enable here and elsewhere:
    # label = Variable(torch.ones(batch_size_by_phase(step), args.n_label)).cuda()
    # label = Variable(
    #     torch.multinomial(
    #         torch.ones(args.n_label), args.batch_size, replacement=True)).cuda()
    myz = Variable(torch.randn(batch_size_by_phase(step), args.nz)).cuda(
        non_blocking=(args.gpu_count > 1))
    myz = utils.normalize(myz)
    myz, label = utils.split_labels_out_of_latent(myz)

    fake_image = generator(myz, label, step, alpha)
    fake_predict, _ = encoder(fake_image, step, alpha, args.use_ALQ)

    loss = fake_predict.mean()
    return loss, fake_image
def make(generator, session, writer):
    """Decode an external latent array (bcolz, at args.z_inpath) into PNGs.

    Latents are read in chunks of 8, decoded with the generator at the
    session's current phase/alpha, and written one PNG per sample into
    args.x_outpath. (``writer`` is unused; kept for interface compatibility.)
    """
    import bcolz

    if not os.path.exists(args.x_outpath):
        os.makedirs(args.x_outpath)
    utils.requires_grad(generator, False)

    z = bcolz.carray(rootdir=args.z_inpath)
    print("Loading external z from {}, shape:".format(args.z_inpath))
    print(np.shape(z))
    total = np.shape(z)[0]

    chunk = 8  # Assume large files: process 8 latents at a time.
    done = 0
    outer_count = 0
    while done < total:
        take = min(chunk, total - done)
        block = z[done:done + take, :]
        latents, labels = utils.split_labels_out_of_latent(
            torch.from_numpy(block.astype(np.float32)))
        decoded = generator(latents, labels, session.phase,
                            session.alpha).detach()
        for ii, img in enumerate(decoded):
            torchvision.utils.save_image(
                img,
                '{}/{}.png'.format(args.x_outpath, str(ii + outer_count * chunk)),
                nrow=args.n_label,
                normalize=True,
                range=(-1, 1),
                padding=0)
        done += take
        outer_count += 1
def train(generator, encoder, g_running, train_data_loader, test_data_loader,
          session, total_steps, train_mode):
    """Main progressive-growing training loop.

    Alternates encoder/discriminator updates and generator updates (every
    args.n_critic batches) until session.sample_i reaches total_steps,
    advancing through fade-in/stabilization phases, refreshing the dataset
    sampler and image pool on phase changes, checkpointing periodically, and
    running evaluation tests.

    train_mode selects config.MODE_GAN (WGAN-style D/G losses) or
    config.MODE_CYCLIC (KL + reconstruction losses on both sides).

    NOTE(review): `writer` is used below but is not a parameter — it is
    presumably a module-level global; confirm.
    """
    pbar = tqdm(initial=session.sample_i, total=total_steps)
    benchmarking = False

    match_x = args.match_x
    generatedImagePool = None

    refresh_dataset = True
    refresh_imagePool = True

    # After the Loading stage, we cycle through successive Fade-in and Stabilization stages
    batch_count = 0

    reset_optimizers_on_phase_start = False

    # TODO Unhack this (only affects the episode count statistics anyway):
    if args.data != 'celebaHQ':
        epoch_len = len(train_data_loader(1, 4).dataset)
    else:
        epoch_len = train_data_loader._len['data4x4']

    if args.step_offset != 0:
        if args.step_offset == -1:
            args.step_offset = session.sample_i
        print("Step offset is {}".format(args.step_offset))
        session.phase += args.phase_offset
        session.alpha = 0.0

    while session.sample_i < total_steps:
        ####################### Phase Maintenance #######################
        steps_in_previous_phases = max(session.phase * args.images_per_stage,
                                       args.step_offset)
        sample_i_current_stage = session.sample_i - steps_in_previous_phases

        # If we can move to the next phase
        if sample_i_current_stage >= args.images_per_stage:
            if session.phase < args.max_phase:  # If any phases left
                iteration_levels = int(sample_i_current_stage / args.images_per_stage)
                session.phase += iteration_levels
                sample_i_current_stage -= iteration_levels * args.images_per_stage
                match_x = args.match_x  # Reset to non-matching phase
                print("iteration B alpha={} phase {} will be reduced to 1 and [max]"
                      .format(sample_i_current_stage, session.phase))

                refresh_dataset = True
                refresh_imagePool = True  # Reset the pool to avoid images of 2 different resolutions in the pool

                if reset_optimizers_on_phase_start:
                    utils.requires_grad(generator)
                    utils.requires_grad(encoder)
                    generator.zero_grad()
                    encoder.zero_grad()
                    session.reset_opt()
                    print("Optimizers have been reset.")

        reso = 4 * 2**session.phase

        # If we can switch from fade-training to stable-training
        if sample_i_current_stage >= args.images_per_stage / 2:
            if session.alpha < 1.0:
                refresh_dataset = True  # refresh dataset generator since no longer have to fade
            match_x = args.match_x * args.matching_phase_x
        else:
            match_x = args.match_x

        # Alpha ramps linearly from 0 to 1 over the first half of the stage.
        session.alpha = min(
            1, sample_i_current_stage * 2.0 / args.images_per_stage
        )  # For 100k, it was 0.00002 = 2.0 / args.images_per_stage

        if refresh_dataset:
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            refresh_dataset = False
            print("Refreshed dataset. Alpha={} and iteration={}".format(
                session.alpha, sample_i_current_stage))
        if refresh_imagePool:
            imagePoolSize = 200 if reso < 256 else 100
            generatedImagePool = utils.ImagePool(
                imagePoolSize
            )  #Reset the pool to avoid images of 2 different resolutions in the pool
            refresh_imagePool = False
            print('Image pool created with size {} because reso is {}'.format(
                imagePoolSize, reso))

        ####################### Training init #######################
        # Noise buffer, re-populated in place each iteration.
        z = utils.conditional_to_cuda(
            Variable(torch.FloatTensor(batch_size(reso), args.nz, 1, 1)),
            args.gpu_count > 1)
        KL_minimizer = KLN01Loss(direction=args.KL, minimize=True)
        KL_maximizer = KLN01Loss(direction=args.KL, minimize=False)
        stats = {}

        #one = torch.FloatTensor([1]).cuda(async=(args.gpu_count>1))
        one = torch.FloatTensor([1]).cuda(non_blocking=(args.gpu_count > 1))

        try:
            real_image, _ = next(train_dataset)
        except (OSError, StopIteration):
            # Sampler exhausted (or file hiccup): rebuild it and retry once.
            train_dataset = data.Utils.sample_data2(train_data_loader,
                                                    batch_size(reso), reso,
                                                    session)
            real_image, _ = next(train_dataset)

        ####################### DISCRIMINATOR / ENCODER ###########################
        utils.switch_grad_updates_to_first_of(encoder, generator)
        encoder.zero_grad()

        #x = Variable(real_image).cuda(async=(args.gpu_count>1))
        x = Variable(real_image).cuda(non_blocking=(args.gpu_count > 1))
        kls = ""

        if train_mode == config.MODE_GAN:
            # Discriminator for real samples
            real_predict, _ = encoder(x, session.phase, session.alpha,
                                      args.use_ALQ)
            # Small drift penalty keeps D's output near zero.
            real_predict = real_predict.mean() \
                - 0.001 * (real_predict ** 2).mean()
            real_predict.backward(-one)  # Towards 1

            # (1) Generator => D. Identical to (2) see below
            fake_predict, fake_image = D_prediction_of_G_output(
                generator, encoder, session.phase, session.alpha)
            fake_predict.backward(one)

            # Grad penalty
            grad_penalty = get_grad_penalty(encoder, x, fake_image,
                                            session.phase, session.alpha)
            grad_penalty.backward()

        elif train_mode == config.MODE_CYCLIC:
            e_losses = []

            # e(X)
            real_z = encoder(x, session.phase, session.alpha, args.use_ALQ)
            if args.use_real_x_KL:
                # KL_real: - \Delta( e(X) , Z ) -> max_e
                KL_real = KL_minimizer(real_z) * args.real_x_KL_scale
                e_losses.append(KL_real)

                stats['real_mean'] = KL_minimizer.samples_mean.data.mean()
                stats['real_var'] = KL_minimizer.samples_var.data.mean()
                stats['KL_real'] = KL_real  #FIXME .data[0]
                kls = "{0:.3f}".format(stats['KL_real'])

            # The final entries are the label. Normal case, just 1. Extract it/them, and make it [b x 1]:
            real_z, label = utils.split_labels_out_of_latent(real_z)
            recon_x = generator(real_z, label, session.phase, session.alpha)
            if args.use_loss_x_reco:
                # match_x: E_x||g(e(x)) - x|| -> min_e
                err = utils.mismatch(recon_x, x, args.match_x_metric) * match_x
                e_losses.append(err)
                stats['x_reconstruction_error'] = err  #FIXME .data[0]

            args.use_wpgan_grad_penalty = False
            grad_penalty = 0.0

            if args.use_loss_fake_D_KL:
                # TODO: The following codeblock is essentially the same as the KL_minimizer part on G side. Unify
                utils.populate_z(z, args.nz + args.n_label, args.noise,
                                 batch_size(reso))
                z, label = utils.split_labels_out_of_latent(z)
                fake = generator(z, label, session.phase,
                                 session.alpha).detach()

                if session.alpha >= 1.0:
                    fake = generatedImagePool.query(fake.data)

                # e(g(Z))
                egz = encoder(fake, session.phase, session.alpha, args.use_ALQ)
                # KL_fake: \Delta( e(g(Z)) , Z ) -> max_e
                KL_fake = KL_maximizer(egz) * args.fake_D_KL_scale

                # Added by Igor
                m = args.kl_margin
                if m > 0.0 and session.phase >= 5:
                    KL_loss = torch.min(
                        torch.ones_like(KL_fake) * m / 2,
                        torch.max(-torch.ones_like(KL_fake) * m,
                                  KL_real + KL_fake))
                    # KL_fake is always negative with abs value typically larger than KL_real.
                    # Hence, the sum is negative, and must be gapped so that the minimum is the negative of the margin.
                else:
                    KL_loss = KL_real + KL_fake
                e_losses.append(KL_loss)
                #e_losses.append(KL_fake)

                stats['fake_mean'] = KL_maximizer.samples_mean.data.mean()
                stats['fake_var'] = KL_maximizer.samples_var.data.mean()
                stats['KL_fake'] = -KL_fake  #FIXME .data[0]
                kls = "{0}/{1:.3f}".format(kls, stats['KL_fake'])

                if args.use_wpgan_grad_penalty:
                    grad_penalty = get_grad_penalty(encoder, x, fake,
                                                    session.phase,
                                                    session.alpha)

            # Update e
            if len(e_losses) > 0:
                e_loss = sum(e_losses)
                stats['E_loss'] = e_loss.detach().cpu().float().numpy()  #np.float32(e_loss.data)
                e_loss.backward()

                if args.use_wpgan_grad_penalty:
                    grad_penalty.backward()
                    stats['Grad_penalty'] = grad_penalty.data

                #book-keeping
                disc_loss_val = e_loss  #FIXME .data[0]

        session.optimizerD.step()

        torch.cuda.empty_cache()

        ######################## GENERATOR / DECODER #############################

        if (batch_count + 1) % args.n_critic == 0:
            utils.switch_grad_updates_to_first_of(generator, encoder)

            for _ in range(args.n_generator):
                generator.zero_grad()
                g_losses = []

                if train_mode == config.MODE_GAN:
                    fake_predict, _ = D_prediction_of_G_output(
                        generator, encoder, session.phase, session.alpha)
                    loss = -fake_predict
                    g_losses.append(loss)

                elif train_mode == config.MODE_CYCLIC:  #TODO We push the z variable around here like idiots
                    def KL_of_encoded_G_output(generator, z):
                        # Closure over reso/session/encoder/KL_minimizer above.
                        utils.populate_z(z, args.nz + args.n_label, args.noise,
                                         batch_size(reso))
                        z, label = utils.split_labels_out_of_latent(z)
                        fake = generator(z, label, session.phase, session.alpha)

                        egz = encoder(fake, session.phase, session.alpha,
                                      args.use_ALQ)
                        # KL_fake: \Delta( e(g(Z)) , Z ) -> min_g
                        return egz, label, KL_minimizer(egz) * args.fake_G_KL_scale, z

                    egz, label, kl, z = KL_of_encoded_G_output(generator, z)

                    if args.use_loss_KL_z:
                        g_losses.append(kl)  # G minimizes this KL
                        stats['KL(Phi(G))'] = kl  #FIXME .data[0]
                        kls = "{0}/{1:.3f}".format(kls, stats['KL(Phi(G))'])

                    if args.use_loss_z_reco:
                        z = torch.cat((z, label), 1)
                        z_diff = utils.mismatch(
                            egz, z, args.match_z_metric
                        ) * args.match_z  # G tries to make the original z and encoded z match
                        g_losses.append(z_diff)

                if len(g_losses) > 0:
                    loss = sum(g_losses)
                    stats['G_loss'] = loss.detach().cpu().float().numpy()  #np.float32(loss.data)
                    loss.backward()

                    # Book-keeping only:
                    gen_loss_val = loss  #FIXME .data[0]

                session.optimizerG.step()

                torch.cuda.empty_cache()

                if train_mode == config.MODE_CYCLIC:
                    if args.use_loss_z_reco:
                        stats['z_reconstruction_error'] = z_diff  #FIXME .data[0]

            accumulate(g_running, generator)

            # NOTE(review): in MODE_GAN, names like fake, egz, kl, z_diff are
            # never bound, so this del would raise NameError — confirm GAN
            # mode is exercised with n_critic reachable.
            del z, x, one, real_image, real_z, KL_real, label, recon_x, fake, egz, KL_fake, kl, z_diff

        if train_mode == config.MODE_CYCLIC:
            if args.use_TB:
                for key, val in stats.items():
                    writer.add_scalar(key, val, session.sample_i)
            elif batch_count % 100 == 0:
                print(stats)

        if args.use_TB:
            writer.add_scalar('LOD', session.phase + session.alpha,
                              session.sample_i)

        ######################## Statistics ########################

        b = batch_size_by_phase(session.phase)
        zr, xr = (stats['z_reconstruction_error'],
                  stats['x_reconstruction_error']
                  ) if train_mode == config.MODE_CYCLIC else (0.0, 0.0)
        e = (session.sample_i / float(epoch_len))
        pbar.set_description(
            ('{0}; it: {1}; phase: {2}; b: {3:.1f}; Alpha: {4:.3f}; Reso: {5}; E: {6:.2f}; KL(real/fake/fakeG): {7}; z-reco: {8:.2f}; x-reco {9:.3f}; real_var {10:.4f}'
             ).format(batch_count + 1, session.sample_i + 1, session.phase, b,
                      session.alpha, reso, e, kls, zr, xr, stats['real_var']))
        #(f'{i + 1}; it: {iteration+1}; b: {b:.1f}; G: {gen_loss_val:.5f}; D: {disc_loss_val:.5f};'
        # f' Grad: {grad_loss_val:.5f}; Alpha: {alpha:.3f}; Reso: {reso}; S-mean: {real_mean:.3f}; KL(real/fake/fakeG): {kls}; z-reco: {zr:.2f}'))

        pbar.update(batch_size(reso))
        session.sample_i += batch_size(reso)  # if not benchmarking else 100
        batch_count += 1

        ######################## Saving ########################

        if batch_count % args.checkpoint_cycle == 0:
            for postfix in {'latest', str(session.sample_i).zfill(6)}:
                session.save_all('{}/{}_state'.format(args.checkpoint_dir,
                                                      postfix))

            print("Checkpointed to {}".format(session.sample_i))

        ######################## Tests ########################

        try:
            evaluate.tests_run(
                g_running, encoder, test_data_loader, session, writer,
                reconstruction=(batch_count % 800 == 0),
                interpolation=(batch_count % 800 == 0),
                collated_sampling=(batch_count % 800 == 0),
                individual_sampling=(
                    batch_count %
                    (args.images_per_stage / batch_size(reso) / 4) == 0))
        except (OSError, StopIteration):
            print("Skipped periodic tests due to an exception.")

    pbar.close()
def interpolate_images(generator, encoder, loader, epoch, prefix, session,
                       writer=None):
    """Build and save an interpolation grid between 4 encoded real images.

    Encodes 4 "corner" images, spherically interpolates (slerp) their
    latents over an 8x8 grid, decodes each interpolated latent, and saves
    both the individual tiles and the collated grid image under
    args.save_dir (or args.aux_outpath). Train mode is restored at the end.

    Fix: ``Variable(real_image, volatile=True)`` — ``volatile`` was removed
    in PyTorch 0.4 and no longer disables autograd; gradients are already
    disabled via utils.requires_grad above, so the flag is dropped.

    NOTE(review): a second ``interpolate_images`` defined later in this file
    shadows this one at import time — confirm which is intended to be live.
    """
    generator.eval()
    encoder.eval()
    utils.requires_grad(generator, False)
    utils.requires_grad(encoder, False)

    nr_of_imgs = 4  # "Corners"
    reso = 4 * 2**session.phase

    if True:  #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
        if session.phase < 1:
            dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
        else:
            dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso, session)
        real_image, _ = next(dataset)
        # volatile=True removed (no-op since PyTorch 0.4); grads already off.
        Utils.interpolation_set_x = utils.conditional_to_cuda(
            Variable(real_image))

    latent_reso_hor = 8
    latent_reso_ver = 8

    x = Utils.interpolation_set_x
    z0 = encoder(Variable(x), session.phase, session.alpha,
                 args.use_ALQ).detach()

    # Output canvas: one extra row for the originals, plus the grid rows.
    t = torch.FloatTensor(
        latent_reso_hor * (latent_reso_ver + 1) + nr_of_imgs, x.size(1),
        x.size(2), x.size(3))

    t[0:nr_of_imgs] = x.data[:]

    special_dir = args.save_dir if not args.aux_outpath else args.aux_outpath
    if not os.path.exists(special_dir):
        os.makedirs(special_dir)

    # Save the original corner images individually.
    for o_i in range(nr_of_imgs):
        single_save_path = '{}{}/interpolations_{}_{}_{}_orig_{}.png'.format(
            special_dir, prefix, session.phase, epoch, session.alpha, o_i)
        grid = torchvision.utils.save_image(
            x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0)

    # Origs on the first row here
    # Corners are: z0[0] ... z0[1]
    #              .
    #              .
    #              z0[2] ... z0[3]
    delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
    delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))

    for y_i in range(latent_reso_ver):
        if False:  #Linear interpolation
            z0_x0 = z0[0] + y_i * delta_z_ver0
            z0_xN = z0[1] + y_i * delta_z_verN
            delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
            z0_x = Variable(
                torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))
            for x_i in range(latent_reso_hor):
                z0_x[x_i] = z0_x0 + x_i * delta_z_hor

        if True:  #Spherical
            t_y = float(y_i) / (latent_reso_ver - 1)
            z0_y1 = Utils.slerp(z0[0].data, z0[2].data, t_y)
            z0_y2 = Utils.slerp(z0[1].data, z0[3].data, t_y)
            z0_x = Variable(
                torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
            for x_i in range(latent_reso_hor):
                t_x = float(x_i) / (latent_reso_hor - 1)
                z0_x[x_i] = Utils.slerp(z0_y1, z0_y2, t_x)

        z0_x, label = utils.split_labels_out_of_latent(z0_x)
        gex = generator(z0_x, label, session.phase, session.alpha).detach()

        # Recall that yi=0 is the original's row:
        t[(y_i + 1) * latent_reso_ver:(y_i + 2) * latent_reso_ver] = gex.data[:]

        for x_i in range(latent_reso_hor):
            single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(
                special_dir, prefix, session.phase, epoch, session.alpha, y_i,
                x_i)
            grid = torchvision.utils.save_image(
                gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0)

    save_path = '{}{}/interpolations_{}_{}_{}.png'.format(
        special_dir, prefix, session.phase, epoch, session.alpha)
    grid = torchvision.utils.save_image(
        t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0)

    # Hacky but this is an easy way to rescale the images to nice big lego format:
    if session.phase < 4:
        im = Image.open(save_path)
        im2 = im.resize((1024, 1024))
        im2.save(save_path)

    if writer:
        writer.add_images('interpolation_latest_{}'.format(session.phase),
                          t / 2 + 0.5,
                          session.phase)  # Igor: changed add_image to add_images

    generator.train()
    encoder.train()
def generate_intermediate_samples(generator, global_i, session, writer=None,
                                  collateImages=True):
    """Sample random latents and save generated images.

    With collateImages=True, one collated PNG grid goes to
    args.save_dir/sample (and optionally to TensorBoard). With
    collateImages=False, individual PNGs (for metric computation) go to a
    fresh ../metrics/<data>/<reso>/<step> directory.
    """
    generator.eval()
    utils.requires_grad(generator, False)

    # Total number is samplesRepeatN * colN * rowN
    # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
    samplesRepeatN = int(args.sample_N / 128) if not collateImages else 1
    reso = 4 * 2**session.phase

    if not collateImages:
        special_dir = '../metrics/{}/{}/{}'.format(args.data, reso,
                                                   str(global_i).zfill(6))
        # Never overwrite an existing metrics dir; append '_' until unique.
        while os.path.exists(special_dir):
            special_dir += '_'

        os.makedirs(special_dir)

    for outer_count in range(samplesRepeatN):
        # Grid dimensions: 1x128 batches for metrics, up to 10x5 when collating.
        colN = 1 if not collateImages else min(
            10, int(np.ceil(args.sample_N / 4.0)))
        rowN = 128 if not collateImages else min(
            5, int(np.ceil(args.sample_N / 4.0)))
        images = []
        for _ in range(rowN):
            myz = utils.conditional_to_cuda(
                Variable(torch.randn(args.n_label * colN, args.nz)))
            myz = utils.normalize(myz)
            myz, input_class = utils.split_labels_out_of_latent(myz)

            new_imgs = generator(myz, input_class, session.phase,
                                 session.alpha).detach().data.cpu()

            images.append(new_imgs)

        if collateImages:
            sample_dir = '{}/sample'.format(args.save_dir)
            if not os.path.exists(sample_dir):
                os.makedirs(sample_dir)
            save_path = '{}/{}.png'.format(sample_dir,
                                           str(global_i + 1).zfill(6))
            torchvision.utils.save_image(torch.cat(images, 0),
                                         save_path,
                                         nrow=args.n_label * colN,
                                         normalize=True,
                                         range=(-1, 1),
                                         padding=0)
            # Hacky but this is an easy way to rescale the images to nice big lego format:
            im = Image.open(save_path)
            im2 = im.resize((1024, 512 if reso < 256 else 1024))
            im2.save(save_path)

            if writer:
                writer.add_images(
                    'samples_latest_{}'.format(session.phase),
                    torch.cat(images, 0),
                    session.phase)  # Igor: changed add_image to add_images
        else:
            for ii, img in enumerate(images):
                torchvision.utils.save_image(
                    img,
                    '{}/{}_{}.png'.format(special_dir,
                                          str(global_i + 1).zfill(6),
                                          ii + outer_count * 128),
                    nrow=args.n_label * colN,
                    normalize=True,
                    range=(-1, 1),
                    padding=0)

    generator.train()
def interpolate_images(generator, encoder, loader, epoch, prefix, session,
                       writer=None):
    """Build and save latent-interpolation images between encoded real images.

    In args.hexmode, 6 corner images are encoded and blended onto a 19-point
    hexagonal layout (slerp on edges, distance-weighted linear blend inside);
    otherwise a 4-corner 8x8 slerp grid is generated as in the earlier
    version of this function.

    NOTE(review): this definition shadows the earlier interpolate_images in
    this file — confirm which one is intended to be live.
    """
    generator.eval()
    encoder.eval()
    with torch.no_grad():
        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        nr_of_imgs = 4 if not args.hexmode else 6  # "Corners"
        reso = session.getReso()

        if True:  #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
            if session.getResoPhase() < 1:
                dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
            else:
                dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso,
                                                  session)
            real_image, _ = next(dataset)
            # NOTE(review): volatile=True has been a no-op since PyTorch 0.4;
            # harmless here because of the enclosing torch.no_grad().
            Utils.interpolation_set_x = Variable(
                real_image, volatile=True).to(device=args.device)

        latent_reso_hor = 8
        latent_reso_ver = 8

        x = Utils.interpolation_set_x
        if args.hexmode:
            x = x[:nr_of_imgs]  #Corners
            z0 = encoder(Variable(x), session.getResoPhase(), session.alpha,
                         args.use_ALQ).detach()
            # Hexagonal lattice coordinates for the 19 sample points.
            X = np.array([-1.7321, 0.0000, 1.7321, -2.5981, -0.8660, 0.8660,
                          2.5981, -3.4641, -1.7321, 0.0000, 1.7321, 3.4641,
                          -2.5981, -0.8660, 0.8660, 2.5981, -1.7321, 0.0000,
                          1.7321])
            Y = np.array([-3.0000, -3.0000, -3.0000, -1.5000, -1.5000, -1.5000,
                          -1.5000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
                          1.5000, 1.5000, 1.5000, 1.5000, 3.0000, 3.0000,
                          3.0000])
            corner_indices = np.array([0, 2, 7, 11, 16, 18])
            inter_indices = [i for i in range(19) if not i in corner_indices]
            edge_indices = np.array([1, 3, 6, 12, 15, 17])
            #distances_to_corners = np.sqrt(np.power(X - X[corner_indices[0]], 2) + np.power(Y - Y[corner_indices[0]], 2))
            distances_to_corners = [
                np.sqrt(
                    np.power(X - X[corner_indices[i]], 2) +
                    np.power(Y - Y[corner_indices[i]], 2)) for i in range(6)
            ]
            # Inverse-distance weights: 1 at a corner, 0 at the farthest point.
            weights = 1.0 - distances_to_corners / np.max(distances_to_corners)

            z0_x = torch.zeros((19, args.nz + args.n_label)).to(
                device=args.device)
            z0_x[corner_indices, :] = z0
            # Edge midpoints via spherical interpolation of adjacent corners.
            z0_x[edge_indices[0], :] = Utils.slerp(z0[0], z0[1], 0.5)
            z0_x[edge_indices[1], :] = Utils.slerp(z0[0], z0[2], 0.5)
            z0_x[edge_indices[2], :] = Utils.slerp(z0[1], z0[3], 0.5)
            z0_x[edge_indices[3], :] = Utils.slerp(z0[2], z0[4], 0.5)
            z0_x[edge_indices[4], :] = Utils.slerp(z0[3], z0[5], 0.5)
            z0_x[edge_indices[5], :] = Utils.slerp(z0[4], z0[5], 0.5)
            # Linear:
            #z0x[inter_indices,:] = (weights.T @ z0)[inter_indices,:]
            z0_x[inter_indices, :] = (torch.from_numpy(
                weights.T.astype(np.float32)).to(device=args.device)
                @ z0)[inter_indices, :]
            z0_x = utils.normalize(z0_x)
        else:
            z0 = encoder(Variable(x), session.getResoPhase(), session.alpha,
                         args.use_ALQ).detach()

        t = torch.FloatTensor(
            latent_reso_hor * (latent_reso_ver + 1) + nr_of_imgs, x.size(1),
            x.size(2), x.size(3))

        t[0:nr_of_imgs] = x.data[:]

        save_root = args.sample_dir if args.sample_dir != None else args.save_dir
        special_dir = save_root if not args.aux_outpath else args.aux_outpath
        if not os.path.exists(special_dir):
            os.makedirs(special_dir)

        # Save the original input images individually.
        for o_i in range(nr_of_imgs):
            single_save_path = '{}{}/hex_interpolations_{}_{}_{}_orig_{}.png'.format(
                special_dir, prefix, session.phase, epoch, session.alpha, o_i)
            grid = torchvision.utils.save_image(
                x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0)

        if args.hexmode:
            z0_x, label = utils.split_labels_out_of_latent(z0_x)
            gex = generator(z0_x, label, session.phase,
                            session.alpha).detach()
            for x_i in range(19):
                single_save_path = '{}{}/hex_interpolations_{}_{}_{}x{}.png'.format(
                    special_dir, prefix, session.phase, epoch, session.alpha,
                    x_i)
                grid = torchvision.utils.save_image(
                    gex.data[x_i] / 2 + 0.5,
                    single_save_path,
                    nrow=1,
                    padding=0)
            # NOTE(review): this return skips the generator.train() /
            # encoder.train() restore below — confirm whether intentional.
            return

        # Origs on the first row here
        # Corners are: z0[0] ... z0[1]
        #              .
        #              .
        #              z0[2] ... z0[3]
        delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
        delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))

        for y_i in range(latent_reso_ver):
            if False:  #Linear interpolation
                z0_x0 = z0[0] + y_i * delta_z_ver0
                z0_xN = z0[1] + y_i * delta_z_verN
                delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                z0_x = Variable(
                    torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))
                for x_i in range(latent_reso_hor):
                    z0_x[x_i] = z0_x0 + x_i * delta_z_hor

            if True:  #Spherical
                t_y = float(y_i) / (latent_reso_ver - 1)
                #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))
                z0_y1 = Utils.slerp(z0[0].data.cpu().numpy(),
                                    z0[2].data.cpu().numpy(), t_y)
                z0_y2 = Utils.slerp(z0[1].data.cpu().numpy(),
                                    z0[3].data.cpu().numpy(), t_y)
                z0_x = Variable(
                    torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                for x_i in range(latent_reso_hor):
                    t_x = float(x_i) / (latent_reso_hor - 1)
                    z0_x[x_i] = torch.from_numpy(
                        Utils.slerp(z0_y1, z0_y2, t_x))

            z0_x, label = utils.split_labels_out_of_latent(z0_x)
            gex = generator(z0_x, label, session.phase,
                            session.alpha).detach()

            # Recall that yi=0 is the original's row:
            t[(y_i + 1) * latent_reso_ver:(y_i + 2) *
              latent_reso_ver] = gex.data[:]

            for x_i in range(latent_reso_hor):
                single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(
                    special_dir, prefix, session.phase, epoch, session.alpha,
                    y_i, x_i)
                grid = torchvision.utils.save_image(
                    gex.data[x_i] / 2 + 0.5,
                    single_save_path,
                    nrow=1,
                    padding=0)

        save_path = '{}{}/interpolations_{}_{}_{}.png'.format(
            special_dir, prefix, session.phase, epoch, session.alpha)
        grid = torchvision.utils.save_image(
            t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0)

        # Hacky but this is an easy way to rescale the images to nice big lego format:
        if session.getResoPhase() < 4:
            im = Image.open(save_path)
            im2 = im.resize((1024, 1024))
            im2.save(save_path)

        #if writer:
        #    writer.add_image('interpolation_latest_{}'.format(session.phase), t / 2 + 0.5, session.phase)

    generator.train()
    encoder.train()
def eval_metrics(generator, encoder, loader, global_i, nr_of_imgs, prefix,
                 reals, reconstructions, session, writer=None):
    """Compute LPIPS (reconstruction) and FID (random-sample) scores.

    Only runs when reconstructions is truthy, nr_of_imgs > 0, and the
    resolution phase is >= 3; scores are written to the TensorBoard writer
    and printed.

    NOTE(review): `prefix` is documented in the original as being "of the
    form /[dir]", and `reals` appears unused in this body — confirm both
    against callers. Also note the local `batch_size = 16` shadows the
    module-level batch_size function used elsewhere in this file.
    """
    # of the form"/[dir]"
    generator.eval()
    encoder.eval()
    with torch.no_grad():
        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        if reconstructions and nr_of_imgs > 0:
            reso = session.getReso()

            batch_size = 16
            n_batches = (nr_of_imgs + batch_size - 1) // batch_size
            n_used_imgs = n_batches * batch_size
            #fid_score = 0
            lpips_score = 0
            lpips_model = lpips.PerceptualLoss(model='net-lin', net='alex',
                                               use_gpu=False)
            real_images = []
            reco_images = []
            rand_images = []
            FIDm = 3  # FID multiplier

            # 32x32 (phase 3) is the practical minimum for these metrics.
            if session.getResoPhase() >= 3:
                for o in range(n_batches * FIDm):
                    # Random samples for FID.
                    myz = Variable(torch.randn(args.n_label * batch_size,
                                               args.nz)).to(device=args.device)
                    myz = utils.normalize(myz)
                    myz, input_class = utils.split_labels_out_of_latent(myz)

                    random_image = generator(
                        myz, input_class, session.phase,
                        session.alpha).detach().data.cpu()
                    rand_images.append(random_image)

                    if o >= n_batches:  #Reconstructions only up to n_batches, FIDs will take n_batches * 3 samples
                        continue

                    dataset = data.Utils.sample_data2(loader, batch_size, reso,
                                                      session)
                    real_image, _ = next(dataset)
                    reco_image = Utils.reconstruct(real_image, encoder,
                                                   generator, session)

                    t = torch.FloatTensor(real_image.size(0) * 2,
                                          real_image.size(1),
                                          real_image.size(2),
                                          real_image.size(3))

                    # compute metrics and write it to tensorboard (if the reso >= 32)
                    #lpips_score = "?"
                    crop_needed = (args.data == 'celeba'
                                   or args.data == 'celebaHQ'
                                   or args.data == 'ffhq')
                    if session.getResoPhase() >= 4 or not crop_needed:  # 32x32 is minimum for LPIPS
                        if crop_needed:
                            real_image = Utils.face_as_cropped(real_image)
                            reco_image = Utils.face_as_cropped(reco_image)
                        real_images.append(real_image)
                        reco_images.append(reco_image.detach().data.cpu())

                real_images = torch.cat(real_images, 0)
                reco_images = torch.cat(reco_images, 0)
                rand_images = torch.cat(rand_images, 0)

                lpips_dist = lpips_model.forward(real_images, reco_images)
                lpips_score = torch.mean(lpips_dist)
                fid_score = calculate_fid_given_images(real_images,
                                                       rand_images,
                                                       batch_size=batch_size,
                                                       cuda=False,
                                                       dims=2048)

                writer.add_scalar('LPIPS', lpips_score, session.sample_i)
                writer.add_scalar('FID', fid_score, session.sample_i)

                print("{}: FID = {}, LPIPS = {}".format(
                    session.sample_i, fid_score, lpips_score))

    encoder.train()
    generator.train()