def reconstruct_from_latent_presentation(input, generator, session):
    """Decode a batch of latent vectors back into images.

    The latent batch is moved to the GPU, renormalized, and split into the
    latent part and the class-label part before being fed through the
    generator at the session's current phase/alpha.

    Args:
        input: latent tensor of shape (batch, nz [+ n_label]); label columns,
            if present, are separated out by utils.split_labels_out_of_latent.
        generator: generator network, called as generator(z, label, phase, alpha).
        session: training session object providing .phase and .alpha.

    Returns:
        Detached CPU tensor holding the generated image batch.
    """
    latent = utils.normalize(Variable(input).cuda())
    latent, label = utils.split_labels_out_of_latent(latent)
    decoded = generator(latent, label, session.phase, session.alpha)
    return decoded.detach().data.cpu()
def D_prediction_of_G_output(generator, encoder, step, alpha):
    """Score generator samples with the encoder/discriminator.

    Draws a fresh random latent batch, generates fake images from it, and
    returns the mean discriminator prediction over that batch (used as the
    generator-side loss term) together with the fake images themselves.

    Args:
        generator: generator network, called as generator(z, label, step, alpha).
        encoder: encoder/discriminator, called as
            encoder(images, step, alpha, use_ALQ); returns (prediction, _).
        step: current progressive-growing phase (determines batch size).
        alpha: fade-in blending coefficient for the current phase.

    Returns:
        (loss, fake_image): scalar mean prediction and the generated batch
        (still attached to the graph so the caller can backpropagate).
    """
    # To use labels, enable here and elsewhere:
    #label = Variable(torch.ones(batch_size_by_phase(step), args.n_label)).cuda()
    # label = Variable(
    #     torch.multinomial(
    #         torch.ones(args.n_label), args.batch_size, replacement=True)).cuda()

    # FIX: `async` became a reserved keyword in Python 3.7 (SyntaxError);
    # PyTorch renamed the Tensor.cuda() argument to `non_blocking` with
    # identical semantics (asynchronous host->device copy).
    myz = Variable(torch.randn(batch_size_by_phase(step), args.nz)).cuda(
        non_blocking=(args.gpu_count > 1))
    myz = utils.normalize(myz)
    myz, label = utils.split_labels_out_of_latent(myz)

    fake_image = generator(myz, label, step, alpha)
    fake_predict, _ = encoder(fake_image, step, alpha, args.use_ALQ)

    loss = fake_predict.mean()
    return loss, fake_image
def generate_intermediate_samples(generator, global_i, session, writer=None, collateImages=True):
    """Generate sample images from random latents at the current phase.

    In collate mode (default) a single grid image is written under
    args.save_dir/sample and optionally logged to TensorBoard; otherwise
    individual PNGs are dumped under ../metrics/ for metric computation.

    NOTE(review): this function is re-defined later in the file with the same
    name; at import time the later definition shadows this one — confirm
    which version is intended to be live.

    Args:
        generator: generator network, called as generator(z, label, phase, alpha).
        global_i: global sample counter, used in output file names.
        session: session object providing .phase and .alpha.
        writer: optional TensorBoard SummaryWriter.
        collateImages: if True, save one collated grid; if False, save
            args.sample_N individual images in batches of 128.

    Returns:
        None (writes image files as a side effect; restores generator to
        train mode before returning).
    """
    generator.eval()
    utils.requires_grad(generator, False)

    # Total number is samplesRepeatN * colN * rowN
    # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
    # FIX: guard with max(1, ...) so that sample_N < 128 still yields one
    # pass instead of zero iterations (matches the later variant of this
    # function, which already has the guard).
    samplesRepeatN = max(1, int(args.sample_N / 128)) if not collateImages else 1
    reso = 4 * 2**session.phase

    if not collateImages:
        special_dir = '../metrics/{}/{}/{}'.format(args.data, reso, str(global_i).zfill(6))
        # Never overwrite an existing sample directory; append '_' until unique.
        while os.path.exists(special_dir):
            special_dir += '_'
        os.makedirs(special_dir)

    for outer_count in range(samplesRepeatN):
        colN = 1 if not collateImages else min(10, int(np.ceil(args.sample_N / 4.0)))
        rowN = 128 if not collateImages else min(5, int(np.ceil(args.sample_N / 4.0)))
        images = []
        for _ in range(rowN):
            myz = utils.conditional_to_cuda(
                Variable(torch.randn(args.n_label * colN, args.nz)))
            myz = utils.normalize(myz)
            myz, input_class = utils.split_labels_out_of_latent(myz)
            new_imgs = generator(
                myz, input_class, session.phase, session.alpha).detach().data.cpu()
            images.append(new_imgs)

        if collateImages:
            sample_dir = '{}/sample'.format(args.save_dir)
            if not os.path.exists(sample_dir):
                os.makedirs(sample_dir)
            save_path = '{}/{}.png'.format(sample_dir, str(global_i + 1).zfill(6))
            torchvision.utils.save_image(
                torch.cat(images, 0),
                save_path,
                nrow=args.n_label * colN,
                normalize=True,
                range=(-1, 1),
                padding=0)
            # Hacky but this is an easy way to rescale the images to nice big lego format:
            im = Image.open(save_path)
            im2 = im.resize((1024, 512 if reso < 256 else 1024))
            im2.save(save_path)

            if writer:
                writer.add_images(
                    'samples_latest_{}'.format(session.phase),
                    torch.cat(images, 0),
                    session.phase)  # Igor: changed add_image to add_images
        else:
            for ii, img in enumerate(images):
                torchvision.utils.save_image(
                    img,
                    '{}/{}_{}.png'.format(special_dir,
                                          str(global_i + 1).zfill(6),
                                          ii + outer_count * 128),
                    nrow=args.n_label * colN,
                    normalize=True,
                    range=(-1, 1),
                    padding=0)

    generator.train()
def generate_intermediate_samples(generator, global_i, session, writer=None, collateImages = True):
    """Generate sample images from random latents (style-mixing-aware variant).

    This later variant supersedes the earlier function of the same name:
    it runs under torch.no_grad(), supports style mixing across generator
    layers (args.randomMixN), and supports rejection sampling based on
    latent reconstruction error (args.rejection_sampling).

    Args:
        generator: generator network, called as generator(z, label, phase, alpha).
        global_i: global sample counter, used in output file names.
        session: session object providing .phase, .alpha, .getReso(),
            .getResoPhase(), and .encoder.
        writer: optional TensorBoard writer (currently unused; the logging
            call below is commented out).
        collateImages: if True, save one collated grid under
            <save_root>/sample; if False, dump individual PNGs in batches
            of 128 under a unique <save_root>/sample/<global_i> directory.
    """
    generator.eval()
    with torch.no_grad():
        save_root = args.sample_dir if args.sample_dir != None else args.save_dir
        utils.requires_grad(generator, False)

        # Total number is samplesRepeatN * colN * rowN
        # e.g. for 51200 samples, outcome is 5*80*128. Please only give multiples of 128 here.
        samplesRepeatN = max(1, int(args.sample_N / 128)) if not collateImages else 1
        reso = session.getReso()

        if not collateImages:
            special_dir = '{}/sample/{}'.format(save_root, str(global_i).zfill(6))
            # Never overwrite an existing sample directory; append '_' until unique.
            while os.path.exists(special_dir):
                special_dir += '_'
            os.makedirs(special_dir)

        for outer_count in range(samplesRepeatN):
            colN = 1 if not collateImages else min(10, int(np.ceil(args.sample_N / 4.0)))
            rowN = 128 if not collateImages else min(5, int(np.ceil(args.sample_N / 4.0)))
            images = []
            myzs = []  # latents kept for rejection-sampling reconstruction check
            for _ in range(rowN):
                # Two independent latents per row: myz[0] is always used;
                # myz[1] only when style mixing is enabled (randomMixN > 1).
                myz = {}
                for i in range(2):
                    myz0 = Variable(torch.randn(args.n_label * colN, args.nz+1)).to(device=args.device)
                    myz[i] = utils.normalize(myz0)
                    #myz[i], input_class = utils.split_labels_out_of_latent(myz0)

                if args.randomMixN <= 1:
                    new_imgs = generator(
                        myz[0],
                        None,  # input_class,
                        session.phase,
                        session.alpha).detach()  #.data.cpu()
                else:
                    # Style mixing: feed myz[0] into layers [0, cut) and
                    # myz[1] into layers [cut, end). Cut is currently fixed
                    # at layer 2 (random choice commented out).
                    max_i = session.phase
                    style_layer_begin = 2 #np.random.randint(low=1, high=(max_i+1))
                    print("Cut at layer {} for the next {} random samples.".format(style_layer_begin, (myz[0].shape)))
                    new_imgs = utils.gen_seq([
                        (myz[0], 0, style_layer_begin),
                        (myz[1], style_layer_begin, -1)
                    ], generator, session).detach()

                images.append(new_imgs)
                myzs.append(myz[0])

            # TODO: Support mixed latents with rejection sampling
            if args.rejection_sampling > 0 and not collateImages:
                # Keep only the top args.rejection_sampling percent of samples
                # per group of 8 batches, ranked by how well the encoder
                # recovers the original latent (cosine mismatch, lower = better).
                filtered_images = []
                batch_size = 8
                for b in range(int(len(images) / batch_size)):
                    img_batch = images[b*batch_size:(b+1)*batch_size]
                    reco_z = session.encoder(Variable(torch.cat(img_batch,0)), session.getResoPhase(), session.alpha, args.use_ALQ).detach()
                    cost_z = utils.mismatchV(torch.cat(myzs[b*batch_size:(b+1)*batch_size],0), reco_z, 'cos')
                    _, ii = torch.sort(cost_z)
                    keeper = args.rejection_sampling / 100.0
                    keepN = max(1, int(len(ii)*keeper))
                    # NOTE(review): `ii` indexes the concatenated per-image
                    # cost vector while `img_batch` is a list of image
                    # *batches*; these only line up when each batch holds a
                    # single image (n_label * colN == 1) — confirm.
                    for iindex in ii[:keepN]:
                        filtered_images.append(img_batch[iindex])
                    # filtered_images.append(img_batch[ii[:keepN]] ) #retain the best ones only
                images = filtered_images

            if collateImages:
                sample_dir = '{}/sample'.format(save_root)
                if not os.path.exists(sample_dir):
                    os.makedirs(sample_dir)
                save_path = '{}/{}.png'.format(sample_dir,str(global_i + 1).zfill(6))
                torchvision.utils.save_image(
                    torch.cat(images, 0),
                    save_path,
                    nrow=args.n_label * colN,
                    normalize=True,
                    range=(-1, 1),
                    padding=0)
                # Hacky but this is an easy way to rescale the images to nice big lego format:
                im = Image.open(save_path)
                im2 = im.resize((1024, 512 if reso < 256 else 1024))
                im2.save(save_path)
                #if writer:
                #    writer.add_image('samples_latest_{}'.format(session.phase), torch.cat(images, 0), session.phase)
            else:
                print("Generating samples at {}...".format(special_dir))
                for ii, img in enumerate(images):
                    ipath = '{}/{}_{}.png'.format(special_dir, str(global_i + 1).zfill(6), ii+outer_count*128)
                    #print(ipath)
                    torchvision.utils.save_image(
                        img,
                        ipath,
                        nrow=args.n_label * colN,
                        normalize=True,
                        range=(-1, 1),
                        padding=0)

    generator.train()
def interpolate_images(generator, encoder, loader, epoch, prefix, session, writer=None):
    """Encode real images and save latent-space interpolations between them.

    Two modes:
      * hexmode (args.hexmode): six "corner" images are encoded, a 19-point
        hexagonal grid of latents is built via slerp (edges) and distance
        weighting (interior), and each decoded image is saved individually.
      * default: four corner images are encoded and an 8x8 grid of slerp
        interpolations between them is decoded, saved per-cell and as one
        collated grid image (with originals on the first row).

    Args:
        generator, encoder: networks at the session's current phase/alpha.
        loader: data loader sampled via data.Utils.sample_data[2].
        epoch, prefix: used in output file names.
        session: session object providing .phase, .alpha, .getReso(),
            .getResoPhase().
        writer: optional TensorBoard writer (logging call commented out).
    """
    generator.eval()
    encoder.eval()
    with torch.no_grad():
        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        nr_of_imgs = 4 if not args.hexmode else 6 # "Corners"
        reso = session.getReso()

        if True: #if Utils.interpolation_set_x is None or Utils.interpolation_set_x.size(2) != reso or (phase >= 1 and alpha < 1.0):
            if session.getResoPhase() < 1:
                dataset = data.Utils.sample_data(loader, nr_of_imgs, reso)
            else:
                dataset = data.Utils.sample_data2(loader, nr_of_imgs, reso, session)
            real_image, _ = next(dataset)
            # NOTE(review): `volatile` was removed from torch.autograd.Variable
            # in PyTorch 0.4 (no effect / warning); redundant under no_grad().
            Utils.interpolation_set_x = Variable(real_image, volatile=True).to(device=args.device)

        latent_reso_hor = 8
        latent_reso_ver = 8

        x = Utils.interpolation_set_x

        if args.hexmode:
            x = x[:nr_of_imgs] #Corners
            z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()
            # (X, Y) coordinates of the 19 nodes of a hexagonal grid;
            # 6 are corners, 6 are edge midpoints, the rest are interior.
            X = np.array([-1.7321,0.0000 ,1.7321 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-3.4641,-1.7321,0.0000 ,1.7321 ,3.4641 ,-2.5981,-0.8660,0.8660 ,2.5981 ,-1.7321,0.0000,1.7321])
            Y = np.array([-3.0000,-3.0000,-3.0000,-1.5000,-1.5000,-1.5000,-1.5000,0.0000,0.0000,0.0000,0.0000,0.0000,1.5000,1.5000,1.5000,1.5000,3.0000,3.0000,3.0000])
            corner_indices = np.array([0, 2, 7, 11, 16, 18])
            inter_indices = [i for i in range(19) if not i in corner_indices]
            edge_indices = np.array([1, 3, 6, 12, 15, 17])
            #distances_to_corners = np.sqrt(np.power(X - X[corner_indices[0]], 2) + np.power(Y - Y[corner_indices[0]], 2))
            # Per-corner Euclidean distance from every grid node; interior
            # latents are blended with weights inversely related to distance.
            distances_to_corners = [np.sqrt(np.power(X - X[corner_indices[i]], 2) + np.power(Y - Y[corner_indices[i]], 2)) for i in range(6)]
            weights = 1.0 - distances_to_corners / np.max(distances_to_corners)
            z0_x = torch.zeros((19,args.nz+args.n_label)).to(device=args.device)
            z0_x[corner_indices,:] = z0
            # Edge midpoints: spherical interpolation halfway between the
            # two adjacent corner latents.
            z0_x[edge_indices[0], :] = Utils.slerp(z0[0], z0[1], 0.5)
            z0_x[edge_indices[1], :] = Utils.slerp(z0[0], z0[2], 0.5)
            z0_x[edge_indices[2], :] = Utils.slerp(z0[1], z0[3], 0.5)
            z0_x[edge_indices[3], :] = Utils.slerp(z0[2], z0[4], 0.5)
            z0_x[edge_indices[4], :] = Utils.slerp(z0[3], z0[5], 0.5)
            z0_x[edge_indices[5], :] = Utils.slerp(z0[4], z0[5], 0.5)
            # Linear:
            #z0x[inter_indices,:] = (weights.T @ z0)[inter_indices,:]
            z0_x[inter_indices,:] = (torch.from_numpy(weights.T.astype(np.float32)).to(device=args.device) @ z0)[inter_indices,:]
            z0_x = utils.normalize(z0_x)
        else:
            z0 = encoder(Variable(x), session.getResoPhase(), session.alpha, args.use_ALQ).detach()

        # Output grid buffer: first row holds the originals, then one row
        # of decoded interpolations per vertical step.
        t = torch.FloatTensor(latent_reso_hor * (latent_reso_ver+1) + nr_of_imgs, x.size(1), x.size(2), x.size(3))

        t[0:nr_of_imgs] = x.data[:]

        save_root = args.sample_dir if args.sample_dir != None else args.save_dir
        special_dir = save_root if not args.aux_outpath else args.aux_outpath

        if not os.path.exists(special_dir):
            os.makedirs(special_dir)

        # Save the original corner images individually ([-1,1] -> [0,1]).
        for o_i in range(nr_of_imgs):
            single_save_path = '{}{}/hex_interpolations_{}_{}_{}_orig_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, o_i)
            grid = torchvision.utils.save_image(x.data[o_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        if args.hexmode:
            z0_x, label = utils.split_labels_out_of_latent(z0_x)
            gex = generator(z0_x, label, session.phase, session.alpha).detach()
            for x_i in range(19):
                single_save_path = '{}{}/hex_interpolations_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, x_i)
                grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?
            # Hexmode saves per-node images only; no collated grid.
            return

        # Origs on the first row here
        # Corners are: z0[0] ... z0[1]
        #              .
        #              .
        #              z0[2] ... z0[3]

        delta_z_ver0 = ((z0[2] - z0[0]) / (latent_reso_ver - 1))
        delta_z_verN = ((z0[3] - z0[1]) / (latent_reso_ver - 1))

        for y_i in range(latent_reso_ver):
            if False: #Linear interpolation
                z0_x0 = z0[0] + y_i * delta_z_ver0
                z0_xN = z0[1] + y_i * delta_z_verN
                delta_z_hor = (z0_xN - z0_x0) / (latent_reso_hor - 1)
                z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0_x0.size(0)))
                for x_i in range(latent_reso_hor):
                    z0_x[x_i] = z0_x0 + x_i * delta_z_hor
            if True: #Spherical
                # Slerp vertically along both side edges, then slerp
                # horizontally between the two resulting latents.
                t_y = float(y_i) / (latent_reso_ver-1)
                #z0_y = Variable(torch.FloatTensor(latent_reso_ver, z0.size(0)))
                z0_y1 = Utils.slerp(z0[0].data.cpu().numpy(), z0[2].data.cpu().numpy(), t_y)
                z0_y2 = Utils.slerp(z0[1].data.cpu().numpy(), z0[3].data.cpu().numpy(), t_y)
                z0_x = Variable(torch.FloatTensor(latent_reso_hor, z0[0].size(0)))
                for x_i in range(latent_reso_hor):
                    t_x = float(x_i) / (latent_reso_hor-1)
                    z0_x[x_i] = torch.from_numpy( Utils.slerp(z0_y1, z0_y2, t_x) )

            z0_x, label = utils.split_labels_out_of_latent(z0_x)
            gex = generator(z0_x, label, session.phase, session.alpha).detach()

            # Recall that yi=0 is the original's row:
            t[(y_i+1) * latent_reso_ver:(y_i+2)* latent_reso_ver] = gex.data[:]

            for x_i in range(latent_reso_hor):
                single_save_path = '{}{}/interpolations_{}_{}_{}_{}x{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha, y_i, x_i)
                grid = torchvision.utils.save_image(gex.data[x_i] / 2 + 0.5, single_save_path, nrow=1, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        save_path = '{}{}/interpolations_{}_{}_{}.png'.format(special_dir, prefix, session.phase, epoch, session.alpha)
        grid = torchvision.utils.save_image(t / 2 + 0.5, save_path, nrow=latent_reso_ver, padding=0) #, normalize=True) #range=(-1,1)) #, normalize=True) #, scale_each=True)?

        # Hacky but this is an easy way to rescale the images to nice big lego format:
        if session.getResoPhase() < 4:
            im = Image.open(save_path)
            im2 = im.resize((1024, 1024))
            im2.save(save_path)

        #if writer:
        #    writer.add_image('interpolation_latest_{}'.format(session.phase), t / 2 + 0.5, session.phase)

    generator.train()
    encoder.train()
def eval_metrics(generator, encoder, loader, global_i, nr_of_imgs, prefix, reals, reconstructions, session, writer=None): # of the form"/[dir]"
    """Compute LPIPS (reconstruction) and FID (random-sample) metrics.

    Encodes/reconstructs real batches and generates random batches, then
    reports LPIPS between reals and reconstructions and FID between reals
    and random samples. Metrics are only computed once the resolution phase
    is high enough (>= 3; LPIPS additionally needs >= 32x32 input, see the
    crop guard below). Results go to the TensorBoard writer and stdout.

    Args:
        generator, encoder: networks at the session's current phase/alpha.
        loader: data loader sampled via data.Utils.sample_data2.
        global_i: global sample counter (unused here; kept for API parity).
        nr_of_imgs: number of reconstruction images to evaluate; 0 disables.
        prefix, reals: unused in this body — presumably for file naming and
            real-image dumping in a sibling variant; verify against callers.
        reconstructions: if falsy, the whole metric computation is skipped.
        session: session object providing .phase, .alpha, .getReso(),
            .getResoPhase(), .sample_i.
        writer: TensorBoard writer; NOTE(review): dereferenced without a
            None check when metrics are computed — confirm callers always
            pass one.
    """
    generator.eval()
    encoder.eval()
    with torch.no_grad():
        utils.requires_grad(generator, False)
        utils.requires_grad(encoder, False)

        if reconstructions and nr_of_imgs > 0:
            reso = session.getReso()

            batch_size = 16
            # Round the requested image count up to whole batches.
            n_batches = (nr_of_imgs+batch_size-1) // batch_size
            n_used_imgs = n_batches * batch_size
            #fid_score = 0
            lpips_score = 0
            lpips_model = lpips.PerceptualLoss(model='net-lin', net='alex', use_gpu=False)
            real_images = []
            reco_images = []
            rand_images = []
            FIDm = 3 # FID multiplier
            if session.getResoPhase() >= 3:
                # FID uses FIDm times more random samples than the
                # reconstruction (LPIPS) evaluation uses real/reco pairs.
                for o in range(n_batches*FIDm):
                    myz = Variable(torch.randn(args.n_label * batch_size, args.nz)).to(device=args.device)
                    myz = utils.normalize(myz)
                    myz, input_class = utils.split_labels_out_of_latent(myz)
                    random_image = generator(
                        myz, input_class, session.phase, session.alpha).detach().data.cpu()
                    rand_images.append(random_image)

                    if o >= n_batches: #Reconstructions only up to n_batches, FIDs will take n_batches * 3 samples
                        continue

                    dataset = data.Utils.sample_data2(loader, batch_size, reso, session)
                    real_image, _ = next(dataset)
                    reco_image = Utils.reconstruct(real_image, encoder, generator, session)
                    # NOTE(review): `t` is allocated but never used in this
                    # body — looks like a leftover from a grid-saving variant.
                    t = torch.FloatTensor(real_image.size(0) * 2, real_image.size(1), real_image.size(2), real_image.size(3))

                    # compute metrics and write it to tensorboard (if the reso >= 32)
                    #lpips_score = "?"
                    crop_needed = (args.data == 'celeba' or args.data == 'celebaHQ' or args.data == 'ffhq')
                    if session.getResoPhase() >= 4 or not crop_needed: # 32x32 is minimum for LPIPS
                        if crop_needed:
                            real_image = Utils.face_as_cropped(real_image)
                            reco_image = Utils.face_as_cropped(reco_image)
                        real_images.append(real_image)
                        reco_images.append(reco_image.detach().data.cpu())

                real_images = torch.cat(real_images, 0)
                reco_images = torch.cat(reco_images, 0)
                rand_images = torch.cat(rand_images, 0)

                lpips_dist = lpips_model.forward(real_images, reco_images)
                lpips_score = torch.mean(lpips_dist)
                fid_score = calculate_fid_given_images(real_images, rand_images, batch_size=batch_size, cuda=False, dims=2048)

                writer.add_scalar('LPIPS', lpips_score, session.sample_i)
                writer.add_scalar('FID', fid_score, session.sample_i)
                print("{}: FID = {}, LPIPS = {}".format(session.sample_i, fid_score, lpips_score))

    encoder.train()
    generator.train()