def save_samples(self, i, model, step, alpha, resolution, fid, run_id):
    """Generate a grid of samples with ``model`` and write it to a single PNG file.

    ``self.sampling_flame_labels``, ``self.pose`` (may be None) and ``self.input_indices`` are consumed in
    ``self.gen_i`` consecutive chunks of ``self.gen_j`` rows. Each generated image is resized to 256x256,
    range-normalised by ``self.overlay_visualizer`` and the results are tiled with ``nrow=self.gen_i``.

    Args:
        i: Training iteration; ``i + 1`` zero-padded to 6 digits appears in the file name.
        model: Generator; the last element of its output (``[-1]``) is used as the image batch.
        step: Progressive-growing step forwarded to ``model``.
        alpha: Fade-in coefficient forwarded to ``model``.
        resolution: Current resolution; only used in the file name.
        fid: FID value, written with 2 decimals in the file name.
        run_id: Run identifier that selects the output sub-directory.
    """
    chunk = self.gen_j
    images = []
    with torch.no_grad():
        for img_idx in range(self.gen_i):
            rows = slice(img_idx * chunk, (img_idx + 1) * chunk)
            flame_param_this_batch = self.sampling_flame_labels[rows]
            pose_this_batch = None if self.pose is None else self.pose[rows]
            idx_this_batch = self.input_indices[rows]
            img_tensor = model(flame_param_this_batch.clone(), pose_this_batch, step=step, alpha=alpha,
                               input_indices=idx_this_batch)[-1]
            img_tensor = self.overlay_visualizer.range_normalize_images(
                dataset_loaders.fast_image_reshape(img_tensor, height_out=256, width_out=256,
                                                   non_diff_allowed=True))
            # Produced under no_grad, so no autograd history to strip; just move to host memory.
            images.append(img_tensor.cpu())
    # NOTE(review): newer torchvision releases renamed make_grid/save_image's ``range`` argument to
    # ``value_range``; ``range`` is kept here for compatibility with the version this code targets.
    torchvision.utils.save_image(
        torch.cat(images, 0),
        f'{cnst.output_root}sample/{str(run_id)}/{str(i + 1).zfill(6)}_res{resolution}x{resolution}_fid_{fid:.2f}.png',
        nrow=self.gen_i, normalize=True, range=(0, 1))
def pairwise_texture_loss(self, tx1, tx2):
    """Return the face-region-masked mean of ``sigmoid((tx1 - tx2) ** 2)``.

    ``self.face_region_only_mask`` is moved (and cached) onto ``tx1``'s device. If its spatial size differs
    from the textures', a resized copy is used for this call only. ``mask[0]`` is broadcast against the
    texture difference.

    Args:
        tx1: First texture batch; the last two dims are treated as (height, width) — presumably
            (batch, channels, H, W), TODO confirm against the caller.
        tx2: Second texture batch, same shape as ``tx1``.

    Returns:
        Scalar tensor. Note that sigmoid(0) == 0.5, so identical textures still give 0.5 * mean(mask).
    """
    if self.face_region_only_mask.device != tx1.device:
        # Cache on the textures' device so the transfer only happens once.
        self.face_region_only_mask = self.face_region_only_mask.to(tx1.device)

    # Compare and resize on the spatial dims. The previous code compared only the width and resized to
    # (tx1.shape[1], tx1.shape[2]), which is (channels, height) for 4-D textures.
    height, width = tx1.shape[-2], tx1.shape[-1]
    if tuple(self.face_region_only_mask.shape[-2:]) != (height, width):
        face_region_only_mask = fast_image_reshape(self.face_region_only_mask, height_out=height,
                                                   width_out=width)
    else:
        face_region_only_mask = self.face_region_only_mask
    return torch.mean(torch.sigmoid(torch.pow(tx1 - tx2, 2)) * face_region_only_mask[0])
def get_gif_from_list_of_params(generator, flame_params, step, alpha, noise, overlay_landmarks, flame_std, flame_mean,
                                overlay_visualizer, rendered_flame_as_condition, use_posed_constant_input,
                                normal_maps_as_cond, camera_params):
    """Generate one frame per row of ``flame_params`` and return them, optionally with conditioning panels.

    ``flame_params`` is un-normalised with ``flame_std``/``flame_mean`` and converted with ``torch.from_numpy``
    (so it is expected to be a NumPy array). The shape/expression/pose/translation slices are rendered with
    ``overlay_visualizer.get_rendered_mesh``. All frames share identity embedding index 13. When the rendering
    and/or normal map is used as conditioning, it is mapped back to [0, 1] and concatenated to the right of
    the generated frame (``dim=-1``). In that case the result is on the CPU.

    ``overlay_landmarks`` is accepted but not used.
    """
    visualizer = OverLayViz() if overlay_visualizer is None else overlay_visualizer

    n_frames = flame_params.shape[0]
    fixed_embeddings = torch.ones(n_frames, dtype=torch.long, device='cuda') * 13

    unnormalized = torch.from_numpy(flame_params * flame_std + flame_mean).cuda()
    mesh_params = tuple(unnormalized[:, ids[0]:ids[1]] for ids in (SHAPE_IDS, EXP_IDS, POSE_IDS, TRANS_IDS))
    normal_map_img, _, _, _, rend_imgs = visualizer.get_rendered_mesh(flame_params=mesh_params,
                                                                      camera_params=camera_params)
    rend_imgs = rend_imgs / 127.0 - 1

    pose = flame_params[:, constants.get_idx_list('GLOBAL_ROT')] if use_posed_constant_input else None

    # NOTE(review): the normal map is fed to the generator (and later shown via (x + 1) / 2) without the
    # `* 2 - 1` rescaling applied in get_image_and_textures — confirm which value range the generator expects.
    if normal_maps_as_cond:
        gen_in = torch.cat((rend_imgs, normal_map_img), dim=1)
    elif rendered_flame_as_condition:
        gen_in = rend_imgs
    else:
        gen_in = flame_params

    generated = generate_from_flame_sequence(generator, gen_in, pose, step, alpha, noise,
                                             input_indices=fixed_embeddings)[-1]
    generated = visualizer.range_normalize_images(
        fast_image_reshape(generated, height_out=rend_imgs.shape[2], width_out=rend_imgs.shape[3]))

    side_panels = []
    if rendered_flame_as_condition:
        side_panels.append((rend_imgs.cpu() + 1) / 2)
    if normal_maps_as_cond:
        side_panels.append((normal_map_img.cpu() + 1) / 2)
    if side_panels:
        generated = torch.cat([generated.cpu(), *side_panels], dim=-1)
    return generated
exp = flm_batch[:, constants.INDICES['EXP'][0]:constants. INDICES['EXP'][1]] pose = flm_batch[:, constants.INDICES['POSE'][0]:constants. INDICES['POSE'][1]] # import ipdb; ipdb.set_trace() light_code = \ flm_batch[:, constants.DECA_IDX['lit'][0]:constants.DECA_IDX['lit'][1]:].view((batch_size_true, 9, 3)) texture_code = flm_batch[:, constants.DECA_IDX['tex'][0]:constants. DECA_IDX['tex'][1]:] norma_map_img, _, _, _, rend_flm = \ overlay_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code, texture_code), camera_params=cam) rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1 norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1 rend_flm = fast_image_reshape(rend_flm, height_out=256, width_out=256, mode='bilinear') norma_map_img = fast_image_reshape(norma_map_img, height_out=256, width_out=256, mode='bilinear') else: rend_flm = None norma_map_img = None gen_1_in = ge_gen_in( flm_batch, rend_flm, norma_map_img, settings_for_runs[run_idx]['normal_maps_as_cond'], settings_for_runs[run_idx]['rendered_flame_as_condition'])
def get_image_and_textures(self, alpha, flame_batch, generator, max_ids, normal_maps_as_cond,
                           rendered_flame_as_condition, step, use_posed_constant_input):
    """Generate images for a batch of FLAME parameters and decode their textures.

    The batch is truncated to ``self.max_num`` rows. When the generator is conditioned on renderings and/or
    normal maps, these are produced with ``self.flame_visualizer.get_rendered_mesh``.
    - 159-column rows: shape/exp/pose/translation are rendered with ``self.camera_params``.
    - 236-column rows (DECA): rendered with per-row camera parameters, sharing the first row's light and
      texture codes. The outputs are clamped to [0, 1], mapped to [-1, 1] and resized to 256x256.
    All rows use one randomly drawn identity index in ``[0, max_ids)``.

    Returns:
        ``(textures, tx_masks, generated_image)``, where the first two come from ``self.flm_tex_dec``.

    Raises:
        ValueError: if rendering is required and the parameter width is neither 159 nor 236.
    """
    flame_batch = flame_batch[:self.max_num, :]  # Just to limit run time
    # Must be read AFTER truncation: the shared light/texture codes below are repeated to this size and
    # have to match the number of shape/exp/pose rows.
    batch_size = flame_batch.shape[0]

    if rendered_flame_as_condition or normal_maps_as_cond:
        shape = flame_batch[:, constants.INDICES['SHAPE'][0]:constants.INDICES['SHAPE'][1]]
        exp = flame_batch[:, constants.INDICES['EXP'][0]:constants.INDICES['EXP'][1]]
        pose = flame_batch[:, constants.INDICES['POSE'][0]:constants.INDICES['POSE'][1]]
        if flame_batch.shape[-1] == 159:  # non DECA params
            translation = flame_batch[:, constants.INDICES['TRANS'][0]:constants.INDICES['TRANS'][1]]
            flame_params = (shape, exp, pose, translation)
            norma_map_img, _, _, _, rend_flm = self.flame_visualizer.get_rendered_mesh(
                flame_params=flame_params, camera_params=self.camera_params)
            rend_flm = rend_flm / 127 - 1.0
            norma_map_img = norma_map_img * 2 - 1
        elif flame_batch.shape[-1] == 236:  # DECA
            cam = flame_batch[:, constants.DECA_IDX['cam'][0]:constants.DECA_IDX['cam'][1]]

            # Same light code for the whole batch
            light_code = flame_batch[0:1, constants.DECA_IDX['lit'][0]:constants.DECA_IDX['lit'][1]]
            light_code = light_code.repeat(batch_size, 1).view((batch_size, 9, 3))

            # Same texture code for the whole batch
            texture_code = flame_batch[0:1, constants.DECA_IDX['tex'][0]:constants.DECA_IDX['tex'][1]]
            texture_code = texture_code.repeat(batch_size, 1)

            norma_map_img, _, _, _, rend_flm = self.flame_visualizer.get_rendered_mesh(
                flame_params=(shape, exp, pose, light_code, texture_code), camera_params=cam)
            rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1
            norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1
            rend_flm = fast_image_reshape(rend_flm, height_out=256, width_out=256, mode='bilinear')
            norma_map_img = fast_image_reshape(norma_map_img, height_out=256, width_out=256, mode='bilinear')
        else:
            raise ValueError('Flame prameter format not understood')

    if use_posed_constant_input:
        pose = flame_batch[:, constants.get_idx_list('GLOBAL_ROT')]
    else:
        pose = None

    # One random identity shared by every row of the batch.
    fixed_identities = torch.ones(batch_size, dtype=torch.long, device='cuda') * np.random.randint(0, max_ids)

    if rendered_flame_as_condition and normal_maps_as_cond:
        gen_in = torch.cat((rend_flm, norma_map_img), dim=1)
    elif rendered_flame_as_condition:
        gen_in = rend_flm
    elif normal_maps_as_cond:
        gen_in = norma_map_img
    else:
        gen_in = flame_batch
    generated_image = generator(gen_in, pose=pose, step=step, alpha=alpha, input_indices=fixed_identities)[-1]

    textures, tx_masks = self.flm_tex_dec(generated_image, flame_batch)
    return textures, tx_masks, generated_image
flame_version='DECA', image_size=256) deca_flame = overlay_visualizer.deca.flame example_out_dir = '/is/cluster/work/pghosh/gif1.0/eye_cntr_images' normalized_eye_cntr_L = [] normalized_eye_cntr_R = [] for indx in range(50): # i = int(random.randint(0, 60_000)) i = indx fig, ax1 = plt.subplots(1, 1) img, flm_rndr, flm_lbl, index = dataset.__getitem__(i, bypass_valid_indexing=False) img = img[None, ...] img = fast_image_reshape(img, height_out=show_img_at_res, width_out=show_img_at_res) img = img[0] ax1.imshow((img.numpy().transpose((1, 2, 0)) + 1)/2) flame_batch = torch.from_numpy(flm_lbl[0][None, ...]).cuda() flame_batch = position_to_given_location(deca_flame, flame_batch) # import ipdb; ipdb.set_trace() shape, expression, pose = (flame_batch[:, 0:100, ], flame_batch[:, 100:150], flame_batch[:, 150:156]) vertices, l_m2d, _ = deca_flame(shape_params=shape, expression_params=expression, pose_params=pose) # l_m2d[:, :, 1] *= -1 # vertices[:, :, 1] *= -1 eye_left_3d = vertices[:, 4051] eye_right_3d = vertices[:, 4597]
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), ]) dataset = FFHQ(real_img_root=data_root, params_dir=params_dir, generic_transform=generic_transform, pose_cam_from_yao=True, resolution=128, normalization_file_path=normalization_file_path) dataset = iter(dataset) resolutions = [4 * 2**_step for _step in range(step_max + 1)] img, flm_prm = next(dataset) flm_prm = torch.from_numpy(flm_prm[0]).cuda()[None, :] real_image = img.cuda()[None, :] real_image_list = [ fast_image_reshape(real_image, height, height) for height in resolutions ] discriminator.eval() with torch.no_grad(): decissions = discriminator(in_img_list=real_image_list, condition=flm_prm, step=step_max, alpha=0.5965608333333333) import ipdb ipdb.set_trace()
image_size=256) example_out_dir = '/is/cluster/work/pghosh/gif1.0/DECA_inferred/saved_rerendered_side_by_side' for indx in range(5): # i = int(random.randint(0, 60_000)) i = indx fig, (ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8) = plt.subplots(1, 8) img, flm_rndr, flm_lbl, index = dataset.__getitem__( i, bypass_valid_indexing=False) img = img[None, ...] flm_rndr = flm_rndr[0][None, ...] img = fast_image_reshape(img, height_out=show_img_at_res, width_out=show_img_at_res) flm_rndr = fast_image_reshape(flm_rndr, height_out=show_img_at_res, width_out=show_img_at_res) img = img[0] flm_rndr = flm_rndr[0] flm_rndr = (flm_rndr.numpy().transpose((1, 2, 0)) + 1) / 2 ax1.imshow((img.numpy().transpose((1, 2, 0)) + 1) / 2) ax2.imshow(flm_rndr[:, :, :3]) if flm_rndr.shape[-1] > 3: ax3.imshow(flm_rndr[:, :, 3:])
# import ipdb; ipdb.set_trace() light_code = torch.cat((light_code, mean_minus_3_sigma_light.view(1,9, 3), mean_plu_3_sigma_light.view(1,9, 3)), dim=0) texture_code = torch.from_numpy(texture_code).cuda()[None, ...].repeat(exp.shape[0]-4, 1) texture_code_neg_3_sigma = texture_code[0:1, :] * 0 texture_code_neg_3_sigma[0, 0] -= 3 texture_code_pos_3_sigma = texture_code[0:1, :] * 0 texture_code_pos_3_sigma[0, 0] += 3 texture_code = torch.cat((texture_code, texture_code_neg_3_sigma, texture_code_pos_3_sigma, texture_code[:2, :]), dim=0) norma_map_img, _, _, _, rend_flm = \ overlay_visualizer.get_rendered_mesh(flame_params=(shape, exp, pose, light_code, texture_code), camera_params=cam) rend_flm = torch.clamp(rend_flm, 0, 1) * 2 - 1 norma_map_img = torch.clamp(norma_map_img, 0, 1) * 2 - 1 rend_flm = fast_image_reshape(rend_flm, height_out=256, width_out=256, mode='bilinear') norma_map_img = fast_image_reshape(norma_map_img, height_out=256, width_out=256, mode='bilinear') # Only for testing # import ipdb; ipdb.set_trace() # rend_flm = torch.from_numpy(np.load('../deca_test_4_flame_rendering.npy')).cuda()[:, :3, :, :] # norma_map_img = torch.from_numpy(np.load('../deca_test_4_flame_rendering.npy')).cuda()[:, 3:, :, :] # flm_rndrds_trn = np.load('../visual_batch_indices_and_flame_renderings.npz') # rend_flm = torch.from_numpy(flm_rndrds_trn['condition_parmas'][:4, :3, :, :]).cuda() # norma_map_img = torch.from_numpy(flm_rndrds_trn['condition_parmas'][:4, 3:, :, :]).cuda() if normal_maps_as_cond and rendered_flame_as_condition: # norma_map_img = norma_map_img * 2 - 1 gen_in = torch.cat((rend_flm, norma_map_img), dim=1) elif normal_maps_as_cond:
texture_code_pos_3_sigma = texture_code[0:1, :] * 0 texture_code_pos_3_sigma[0, 0] += 3 texture_code = torch.cat((texture_code, texture_code_neg_3_sigma, texture_code_pos_3_sigma, texture_code[:2, :]), dim=0) verts, landmarks2d, landmarks3d = overlay_visualizer.deca.flame(shape_params=shape, expression_params=exp, pose_params=pose) landmarks2d_projected = batch_orth_proj(landmarks2d, cam) landmarks2d_projected[:, :, 1:] *= -1 trans_verts = batch_orth_proj(verts, cam) trans_verts[:, :, 1:] = -trans_verts[:, :, 1:] right_albedos = overlay_visualizer.flametex(texture_code) # albedos = torch.tensor([47, 59, 65], dtype=torch.float32)[None, ..., None, None].cuda()/255.0*1.5 albedos = torch.tensor([0.6, 0.6, 0.6], dtype=torch.float32)[None, ..., None, None].cuda() albedos = albedos.repeat(texture_code.shape[0], 1, 512, 512) albedos[-4:] = fast_image_reshape(right_albedos[-4:], height_out=512, width_out=512) rendering_results = overlay_visualizer.deca.render(verts, trans_verts, albedos, lights=light_code, light_type='point', cull_backfaces=True) textured_images, normals, alpha_img = rendering_results['images'], rendering_results['normals'],\ rendering_results['alpha_images'] normal_images = overlay_visualizer.deca.render.render_normal(trans_verts, normals) rend_flm = torch.clamp(textured_images, 0, 1) * 2 - 1 norma_map_img = torch.clamp(normal_images, 0, 1) * 2 - 1 id_start = 20 save_dir_teaser = os.path.join(save_dir_tsr, f'images_gt_FLAME/') os.makedirs(save_dir_teaser, exist_ok=True)
pose_batch = pose[batch_idx:batch_idx+32] cam_batch = cam[batch_idx:batch_idx+32] true_batch_size = cam_batch.shape[0] light_code = light_texture_id_code_source['light_code'][id].astype('float32')[None, ...].\ repeat(true_batch_size, axis=0) texture_code = light_texture_id_code_source['texture_code'][id].astype('float32')[None, ...].\ repeat(true_batch_size, axis=0) # import ipdb; ipdb.set_trace() norma_map_img_batch, _, _, _, rend_flm_batch = \ overlay_visualizer.get_rendered_mesh(flame_params=(shape_batch, exp_batch, pose_batch, torch.from_numpy(light_code).cuda(), torch.from_numpy(texture_code).cuda()), camera_params=cam_batch) rend_flm_batch = torch.clamp(rend_flm_batch, 0, 1) * 2 - 1 norma_map_img_batch = torch.clamp(norma_map_img_batch, 0, 1) * 2 - 1 rend_flm_batch = fast_image_reshape(rend_flm_batch, height_out=256, width_out=256, mode='bilinear') norma_map_img_batch = fast_image_reshape(norma_map_img_batch, height_out=256, width_out=256, mode='bilinear') # with back face culling and white texture norma_map_img_batch_to_save, _, _, _, rend_flm_batch_to_save = \ overlay_visualizer.get_rendered_mesh(flame_params=(shape_batch, exp_batch, pose_batch, torch.from_numpy(light_code).cuda(), torch.from_numpy(texture_code).cuda()), camera_params=cam_batch, cull_backfaces=True, constant_albedo=0.6) rend_flm_batch_to_save = torch.clamp(rend_flm_batch_to_save, 0, 1) * 2 - 1 norma_map_img_batch_to_save = torch.clamp(norma_map_img_batch_to_save, 0, 1) * 2 - 1 rend_flm_batch_to_save = fast_image_reshape(rend_flm_batch_to_save, height_out=256, width_out=256, mode='bilinear') norma_map_img_batch_to_save = fast_image_reshape(norma_map_img_batch_to_save, height_out=256, width_out=256, mode='bilinear') gen_in = torch.cat((rend_flm_batch, norma_map_img_batch), dim=1)
# i = 23650 fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5) print(f'getting_img{i}') img, flm_rndr, flm_lbl, index = dataset.__getitem__( i, bypass_valid_indexing=True) # if flm_rndr[0].max() < 0.1: # import ipdb; ipdb.set_trace() img = img[None, ...] flm_rndr = flm_rndr[0][None, ...] if with_blur: img_blr = blur_module(img) flm_rndr_blr = blur_module(flm_rndr) img_blr = fast_image_reshape(img_blr, height_out=show_img_at_res, width_out=show_img_at_res) flm_rndr_blr = fast_image_reshape(flm_rndr_blr, height_out=show_img_at_res, width_out=show_img_at_res) img_blr = img_blr[0] flm_rndr_blr = flm_rndr_blr[0] img = fast_image_reshape(img, height_out=show_img_at_res, width_out=show_img_at_res) flm_rndr = fast_image_reshape(flm_rndr, height_out=show_img_at_res, width_out=show_img_at_res)
def train(args, dataset, generator, discriminator_flm, fid_computer, flame_param_est, used_samples, step):
    """Run the progressive conditional-GAN training loop.

    Each iteration:
    - draws a batch from ``sample_data``;
    - updates ``discriminator_flm`` with softplus (logistic) losses, plus ``losses.grad_penalty_loss`` on
      the real images every 16th iteration;
    - runs ``gen_itr_count`` generator updates, each followed by an EMA update of ``g_running``.
    The progressive-growing ``step`` advances once ``used_samples`` exceeds ``2 * args.phase``. Every
    500 iterations FID is computed and a sample grid is saved. Every 1000 iterations a ``.model``
    checkpoint and a ``.npz`` with the schedule state are written.

    Relies on module-level names not passed as arguments: ``g_optimizer``, ``d_optimizer_flm``, ``g_running``,
    ``n_critic``, ``cnst``, ``losses``, ``generic_utils``, ``sample_data``, ``fast_image_reshape``, ``tqdm``.

    Args:
        args: Parsed options (batch sizes, learning rates, phase length, loss switches, run_id, ...).
        dataset: Training dataset; also used for accumulate_batches_of_flm, un_normalize_flame and
            get_10k_flame_params.
        generator: Generator being trained.
        discriminator_flm: Conditional discriminator being trained.
        fid_computer: Object exposing ``get_fid(images)``.
        flame_param_est: Optional sampler of FLAME parameters for the visualisation grid. When None, the
            grid is seeded once from the first FID batch.
        used_samples: Samples already consumed at the current resolution (for resuming).
        step: Initial progressive-growing step (resolution = 4 * 2 ** step).
    """
    run_avg_rate = 0.999
    interp_loss_run_avg = 0
    # Generator pose input; it is never taken from the data loader.
    pose = None
    # Defined before the loop so the progress message can be built on iterations where no generator update
    # ran (gen_itr_count == 0 when n_critic > 1); previously this raised UnboundLocalError.
    embedding_reg_loss = 0.0
    interp_loss = None

    if args.gen_reg_type.upper() == 'PATH_LEN_REG':
        pl_reg = losses.PathLengthRegularizor()

    fid = np.nan
    if flame_param_est is None:
        fake_flame = None
        fake_indices = None
    else:
        fake_flame = flame_param_est.get_samples(n_samples=50, shuffle=True, normalize=True)
        fake_flame = torch.from_numpy(fake_flame).cuda()
        fake_indices = torch.from_numpy(
            np.random.randint(0, args.embedding_vocab_size, size=fake_flame.shape[0])).cuda()

    viz_saver = generic_utils.VisualizationSaver(
        gen_i=10, gen_j=5, sampling_flame_labels=fake_flame, dataset=dataset, input_indices=fake_indices)

    resolution = 4 * 2**step
    resolutions = [4 * 2**_step for _step in range(step + 1)]

    loader = sample_data(dataset, args.batch.get(resolution, args.batch_default), resolutions, debug=args.debug)
    data_loader = iter(loader)

    if args.apply_texture_space_interpolation_loss:
        interp_tex_loss = losses.InterpolatedTextureLoss(
            max_images_in_batch=args.batch.get(resolution, args.batch_default))

    # Increasing generator learning rate as it falls behind the discriminator (after adding a more
    # informative loss to the discriminator).
    generic_utils.adjust_lr(g_optimizer, args.lr.get(resolution, 0.001), args.use_styled_conv_stylegan2)

    pbar = tqdm(range(3_000_000))

    generic_utils.requires_grad(generator, False)
    generic_utils.requires_grad(discriminator_flm, True)

    generic_utils.adjust_lr(d_optimizer_flm, args.lr.get(resolution, 0.001), args.use_styled_conv_stylegan2)

    # Running averages start at 0 and are only ever updated with floats.
    disc_loss_val = 0
    gen_loss_val = 0

    max_step = int(math.log2(args.max_size)) - 2
    final_progress = False

    for i in pbar:
        generic_utils.requires_grad(discriminator_flm, True)
        discriminator_flm.zero_grad()

        alpha = min(1, 1 / args.phase * (used_samples + 1))

        if (resolution == args.init_size and args.ckpt is None) or final_progress:
            alpha = 1

        # Switching resolution code
        if used_samples > args.phase * 2:
            used_samples = 0
            step += 1

            if step > max_step:
                step = max_step
                final_progress = True
            else:
                alpha = 0

            resolution = 4 * 2**step
            time_now = datetime.now()
            print(f'{time_now} : Resolution is : ' + str(resolution))
            resolutions = [4 * 2**_step for _step in range(step + 1)]

            loader = sample_data(dataset, args.batch.get(resolution, args.batch_default), resolutions,
                                 debug=args.debug)
            data_loader = iter(loader)

            generic_utils.adjust_lr(g_optimizer, args.lr.get(resolution, 0.001), args.use_styled_conv_stylegan2)

        try:
            real_image, flm_rndr, flm_lbls, input_indices = next(data_loader)  # Real images in different scales
        except (OSError, StopIteration):
            data_loader = iter(loader)
            real_image, flm_rndr, flm_lbls, input_indices = next(data_loader)

        if not args.rendered_flame_as_condition and not args.normal_maps_as_cond:
            flm_rndr = flm_lbls
            dataset.accumulate_batches_of_flm(flm_lbls[0], pose)
        else:
            dataset.accumulate_batches_of_flm(flm_rndr[0], pose)

        real_image = real_image.cuda()
        real_image_list = [fast_image_reshape(real_image, resolutions[-1], resolutions[-1])]
        flm_lbls = flm_lbls[0].cuda()
        flm_rndr = flm_rndr[0].cuda()
        input_indices = input_indices.cuda()

        used_samples += args.batch.get(resolution, args.batch_default)

        # ############################ Training Discriminator ############################
        for scaled_real_image in real_image_list:
            scaled_real_image.requires_grad = True

        real_img_condition = flm_rndr
        real_img_condition.requires_grad = True
        real_scores_flm, _ = discriminator_flm(real_image_list, condition=real_img_condition, step=step, alpha=alpha)
        real_Dloss_flm = torch.nn.functional.softplus(-real_scores_flm).mean()

        if (i + 1) % 16 == 0:  # To save time do only every 16th iteration. otherwise 17 sec per itr
            grad_penalty_flm = losses.grad_penalty_loss(real_image_list, real_scores_flm, step=None)
            real_Dloss_flm += grad_penalty_flm.mean()

        if args.rendered_flame_as_condition or args.normal_maps_as_cond:
            gen_in1 = flm_rndr.clone()
        else:
            gen_in1 = flm_lbls.clone()

        fake_image_list = generator(gen_in1, pose, step=step, alpha=alpha, input_indices=input_indices)
        cond_fake_imgs = gen_in1

        fake_image_list[0] = fake_image_list[0].detach()
        if args.shfld_cond_as_neg_smpl:
            # Extra negatives: a second copy of the fakes, paired with shuffled real conditions.
            fake_image_list[0] = fake_image_list[0].repeat((2, 1, 1, 1))
            shuffle_indices = generic_utils.get_unique_shuffle_indices(real_img_condition.shape[0])
            cond_discrim_fake_imgs = torch.cat((cond_fake_imgs, real_img_condition[shuffle_indices, :]), dim=0)
        else:
            cond_discrim_fake_imgs = cond_fake_imgs

        fake_scores_flm, _ = discriminator_flm(
            fake_image_list, condition=cond_discrim_fake_imgs, step=step, alpha=alpha)
        fake_Dloss = F.softplus(fake_scores_flm).mean()
        disc_loss = real_Dloss_flm + fake_Dloss
        disc_loss.backward()
        # Report exactly the optimised loss (fake_Dloss used to be added a second time here).
        disc_loss_val = disc_loss.item()
        d_optimizer_flm.step()

        # ############################ Training Generator ############################
        if n_critic >= 1:
            gen_itr_count = 1 if (i + 1) % n_critic == 0 else 0
        else:  # letting generator to be trained more
            gen_itr_count = int(1 / n_critic)

        generic_utils.requires_grad(generator, True)
        generic_utils.requires_grad(discriminator_flm, False)

        for _ in range(gen_itr_count):
            generator.zero_grad()
            fake_image_list = generator(gen_in1, pose, step=step, alpha=alpha, input_indices=input_indices)

            # cond_fake_imgs.detach() is necessary because it will try and propagate gradients and will retain
            # cleared buffers from training on real data and cause double call of backward!
            predict_flm, _ = discriminator_flm(
                fake_image_list, condition=cond_fake_imgs.detach(), step=step, alpha=alpha)
            fake_gen_loss = F.softplus(-predict_flm).mean()

            gen_reg_type = args.gen_reg_type.upper()
            if gen_reg_type == 'PATH_LEN_REG':
                gen_flm_weight = 2
                fake_gen_loss += gen_flm_weight * pl_reg.path_length_reg(
                    generator, step=step, alpha=alpha, input_indices=input_indices)
            elif gen_reg_type == 'DIRECT_GRAD_REG':
                # Changes in flame input should cause as small output change as possible
                gen_flm_weight = 1e-8 * 8
                fake_gen_loss += gen_flm_weight * losses.grad_penalty_loss(
                    inputs=[gen_in1], outs=torch.pow(fake_image_list[-1], 2), step=None)

            # Embedding regularisation loss
            if args.embedding_vocab_size != 1:
                embedding_reg_loss = args.embedding_reg_weight * losses.l2_reg(generator.module.z_to_w)
                fake_gen_loss += embedding_reg_loss

            # Interpolation loss: the texture must stay the same even when the face is moved with different
            # flame parameters. This is the only reason why FLAME labels are necessary at train time.
            if args.apply_texture_space_interpolation_loss:
                flm_intrp_batch = flm_lbls[:-1, :159] + \
                    np.random.uniform(0, 1) * (flm_lbls[1:, :159] - flm_lbls[:-1, :159])
                # During interpolation light and texture code should stay constant. Done in the loss function
                flm_intrp_batch = torch.cat((flm_intrp_batch, flm_lbls[:-1, 159:]), dim=-1)
                interp_loss = interp_tex_loss.tex_sp_intrp_loss(
                    dataset.un_normalize_flame(flm_intrp_batch), generator, step=step, alpha=alpha,
                    max_ids=args.embedding_vocab_size, normal_maps_as_cond=args.normal_maps_as_cond,
                    use_posed_constant_input=args.use_posed_constant_input,
                    rendered_flame_as_condition=args.rendered_flame_as_condition)
                if args.adaptive_interp_loss:
                    interp_loss *= 0.25 * fake_gen_loss.detach() / interp_loss.detach()
                fake_gen_loss += interp_loss

            fake_gen_loss.backward()
            g_optimizer.step()
            gen_loss_val = gen_loss_val * run_avg_rate + fake_gen_loss.item() * (1 - run_avg_rate)

            # decay factor copied from STG2
            generic_utils.accumulate(g_running, generator, decay=0.5**(32 / (10 * 1000)))

        generic_utils.requires_grad(generator, False)

        if (i + 1) % 1000 == 0:
            md_chk_pt_name = f'{cnst.output_root}checkpoint/{str(args.run_id)}/{str(i + 1).zfill(6)}_{alpha}.model'
            chk_pt_dict = {
                'generator_running': g_running.state_dict(),
                'generator': generator.state_dict(),
                'g_optimizer': g_optimizer.state_dict(),
                'discriminator_flm': discriminator_flm.state_dict(),
                'd_optimizer_flm': d_optimizer_flm.state_dict(),
            }
            torch.save(chk_pt_dict, md_chk_pt_name)
            # The 'used_sampless' key name is kept for readers of these files. Its value previously
            # referenced an undefined name and raised NameError.
            np.savez(md_chk_pt_name.replace('.model', '.npz'), step=step, used_sampless=used_samples,
                     alpha=alpha, resolution=resolution)

        if (i + 1) % 500 == 0:
            # FID computation. Distinct names so the training-loop `pose` / `input_indices` are not overwritten
            # by the 10k evaluation set.
            fid_flame_params, fid_input_indices, fid_pose = dataset.get_10k_flame_params()
            image_tensor = generic_utils.get_images_from_flame_params(
                fid_flame_params, fid_pose, g_running, step=step, alpha=alpha, input_indices=fid_input_indices)
            fid = fid_computer.get_fid(image_tensor)

        state_msg = (
            f'Size: {4 * 2 ** step}; G: {gen_loss_val:.3f}; D: {disc_loss_val:.3f};'
            f' fid: {fid:.0f}')

        if args.embedding_vocab_size != 1:
            state_msg += f', embd_reg_l: {embedding_reg_loss:.3f}; '

        if args.apply_texture_space_interpolation_loss:
            if interp_loss is not None:
                interp_loss_run_avg = interp_loss_run_avg * run_avg_rate + interp_loss.item() * (1 - run_avg_rate)
            state_msg += f', interp_l: {interp_loss_run_avg:.3f}; '

        pbar.set_description(state_msg)

        if (i + 1) % 500 == 0:
            if flame_param_est is None:
                condition_parmas = torch.from_numpy(fid_flame_params[:50]).cuda()
                pose_for_saving = None if fid_pose is None else torch.from_numpy(fid_pose[:50]).cuda()
                inpt_idxs = torch.from_numpy(fid_input_indices[:50]).cuda()
                viz_saver.set_flame_params(pose_for_saving, condition_parmas, inpt_idxs)
                # Seed the visualisation grid only once.
                flame_param_est = 0

            viz_saver.save_samples(i, model=g_running, step=step, alpha=alpha, resolution=resolution, fid=fid,
                                   run_id=args.run_id)