def __init__(self, max_images_in_batch):
    """Load the FLAME texture space, face-region mask and camera; precompute index pairs.

    Args:
        max_images_in_batch: Batch size. ``self.max_num`` is set to
            ``max_images_in_batch - 1`` and ``self.pairs`` holds every unordered
            index pair drawn from ``range(self.max_num)``.
    """
    from itertools import combinations

    # NOTE(review): allow_pickle=True unpickles the file contents; acceptable only
    # because the path is a project-controlled constant. Never point this at
    # user-supplied files.
    texture_data = np.load(cnst.flame_texture_space_dat_file, allow_pickle=True,
                           encoding='latin1').item()
    self.flm_tex_dec = FlameTextureSpace(texture_data, data_un_normalizer=None).cuda()
    self.flame_visualizer = OverLayViz()

    # Context manager releases the PIL file handle right away instead of
    # leaving it open until the Image object happens to be garbage-collected.
    with Image.open(cnst.face_region_mask_file) as mask_img:
        mask = np.array(mask_img)
    # Keep only the first channel and move it to the front: (H, W, C) -> (1, H, W).
    mask = mask[:, :, 0:1].transpose((2, 0, 1))
    # Scale to [0, 1] float32 and prepend a batch axis -> (1, 1, H, W).
    self.face_region_only_mask = (torch.from_numpy(mask.astype('float32')) / 255.0)[None, ...]
    print('read face region mask')

    flength = 5000
    cam_t = np.array([0., 0., 0])
    self.camera_params = camera_ringnetpp((512, 512), trans=cam_t, focal=flength)

    self.max_num = max_images_in_batch - 1
    # All pairs (i, j) with 0 <= i < j < max_num, in the same lexicographic order
    # as a nested i/j loop. reshape(-1, 2) keeps the result 2-D with an integer
    # dtype even when there are no pairs (max_num < 2), so column indexing such as
    # pairs[:, 0] remains valid.
    self.pairs = np.array(list(combinations(range(self.max_num), 2)),
                          dtype=int).reshape(-1, 2)
def save_samples(self, i, model, step, alpha, resolution, fid, run_id):
    """Generate ``self.gen_i`` batches of samples with ``model`` and save them as one PNG grid.

    Each batch takes ``self.gen_j`` consecutive entries from
    ``self.sampling_flame_labels``, ``self.input_indices`` and (if set)
    ``self.pose``. The last output of ``model`` is resized to 256x256,
    range-normalised and collected on the CPU. All batches are concatenated and
    written to ``<cnst.output_root>sample/<run_id>/<i+1, 6 digits>_res<R>x<R>_fid_<fid>.png``.

    Args:
        i: Iteration counter; ``i + 1`` is used, zero-padded, in the file name.
        model: Generator called as
            ``model(flame_params, pose, step=..., alpha=..., input_indices=...)``.
        step: Passed through to ``model``.
        alpha: Passed through to ``model``.
        resolution: Only used in the output file name.
        fid: FID score, formatted with 2 decimals in the file name.
        run_id: Run identifier; names the output sub-directory.
    """
    import os

    batch_size = self.gen_j
    images = []
    with torch.no_grad():
        for img_idx in range(self.gen_i):
            # Same slice for every per-sample tensor of this batch.
            batch_slice = slice(img_idx * batch_size, (img_idx + 1) * batch_size)
            flame_param_this_batch = self.sampling_flame_labels[batch_slice]
            pose_this_batch = None if self.pose is None else self.pose[batch_slice]
            idx_this_batch = self.input_indices[batch_slice]

            # Only the last element of the model's output is kept.
            img_tensor = model(flame_param_this_batch.clone(), pose_this_batch,
                               step=step, alpha=alpha, input_indices=idx_this_batch)[-1]
            img_tensor = self.overlay_visualizer.range_normalize_images(
                dataset_loaders.fast_image_reshape(img_tensor, height_out=256, width_out=256,
                                                   non_diff_allowed=True))
            images.append(img_tensor.detach().cpu())

    out_dir = f'{cnst.output_root}sample/{str(run_id)}/'
    # save_image does not create missing directories; make sure the run folder exists.
    os.makedirs(out_dir, exist_ok=True)
    # NOTE(review): the `range=` kwarg is deprecated in newer torchvision in favour of
    # `value_range=`; kept as-is for compatibility with the installed version - confirm.
    torchvision.utils.save_image(
        torch.cat(images, 0),
        f'{out_dir}{str(i + 1).zfill(6)}_res{resolution}x{resolution}_fid_{fid:.2f}.png',
        nrow=self.gen_i, normalize=True, range=(0, 1))
# General settings save_images = True code_size = 236 use_inst_norm = True core_tensor_res = 4 resolution = 256 alpha = 1 step_max = int(np.log2(resolution) - 2) root_out_dir = f'{cnst.output_root}sample/' num_smpl_to_eval_on = 1000 use_styled_conv_stylegan2 = True flength = 5000 cam_t = np.array([0., 0., 0]) camera_params = camera_ringnetpp((512, 512), trans=cam_t, focal=flength) run_ids_1 = [ 29, ] # with sqrt(2) # run_ids_1 = [7, 24, 8, 3] # run_ids_1 = [7, 8, 3] settings_for_runs = \ {24: {'name': 'vector_cond', 'model_idx': '216000_1', 'normal_maps_as_cond': False, 'rendered_flame_as_condition': False, 'apply_sqrt2_fac_in_eq_lin': False}, 29: {'name': 'full_model', 'model_idx': '026000_1', 'normal_maps_as_cond': True, 'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': True}, 7: {'name': 'flm_rndr_tex_interp', 'model_idx': '488000_1', 'normal_maps_as_cond': False, 'rendered_flame_as_condition': True, 'apply_sqrt2_fac_in_eq_lin': False}, 3: {'name': 'norm_mp_tex_interp', 'model_idx': '040000_1', 'normal_maps_as_cond': True,
import tqdm from my_utils.flm_dynamic_fit_overlay import camera_ringnetpp import numpy as np from my_utils.visualize_flame_overlay import OverLayViz from my_utils.ringnet_overlay.util import tensor_vis_landmarks from PIL import Image import torch import matplotlib.pyplot as plt resolution = 512 flength = 5000 cam_t = np.array([0., 0., 0]) camera_params = camera_ringnetpp((resolution, resolution), trans=cam_t, focal=flength) overlay_visualizer = OverLayViz(full_neck=False, add_random_noise_to_background=False, background_img=None, texture_pattern_name='CHKR_BRD_FLT_TEETH', flame_version='FLAME_2020_revisited', image_size=resolution) dest_dir = '/is/cluster/work/pghosh/gif1.0/fitting_viz/flame_2020/' # dest_dir = '/is/cluster/work/pghosh/gif1.0/fitting_viz/flame_old/' for i in tqdm.tqdm(range(0, 300)): img_file = '/is/cluster/scratch/partha/face_gan_data/FFHQ/images1024x1024/' + str( i).zfill(5) + '.png' img_original = Image.open(img_file).resize((resolution, resolution)) images = torch.from_numpy(