def get_gif_from_list_of_params(generator, flame_params, step, alpha, noise, overlay_landmarks, flame_std,
                                flame_mean, overlay_visualizer, rendered_flame_as_condition,
                                use_posed_constant_input, normal_maps_as_cond, camera_params):
    """Render a batch of FLAME parameters and generate the matching images.

    The normalised ``flame_params`` are de-normalised with ``flame_std`` /
    ``flame_mean`` and rendered through ``overlay_visualizer.get_rendered_mesh``.
    The chosen conditioning input is then passed to ``generate_from_flame_sequence``.

    Args:
        generator: network forwarded to ``generate_from_flame_sequence``.
        flame_params: normalised FLAME parameters. ``flame_params * flame_std + flame_mean``
            is handed to ``torch.from_numpy``, so presumably a numpy array of shape
            (batch, n_params) — TODO confirm.
        step, alpha, noise: forwarded unchanged to ``generate_from_flame_sequence``.
        overlay_landmarks: accepted for interface compatibility; unused here.
        flame_std, flame_mean: de-normalisation statistics for ``flame_params``.
        overlay_visualizer: renderer; a new ``OverLayViz()`` is created when ``None``.
        rendered_flame_as_condition: condition the generator on the rendering
            instead of the raw parameters, and append the rendering to the output.
        use_posed_constant_input: pass the ``GLOBAL_ROT`` columns of
            ``flame_params`` to the generator as ``pose``.
        normal_maps_as_cond: condition on rendering and normal map concatenated
            along dim 1 (this takes precedence over the choice above). Also
            append the normal map to the output.
        camera_params: camera passed to ``get_rendered_mesh``.

    Returns:
        The range-normalised generated images, resized to the rendering's
        height/width. When requested, the rendering and/or normal map are
        concatenated along the last axis, each mapped through ``(x + 1) / 2``.
    """
    if overlay_visualizer is None:
        overlay_visualizer = OverLayViz()

    # Every sample in the batch is generated with the same embedding index (13).
    fixed_embeddings = 13 * torch.ones(flame_params.shape[0], dtype=torch.long, device='cuda')

    denormalised = torch.from_numpy(flame_params * flame_std + flame_mean).cuda()
    # Column groups: shape, expression, pose, translation.
    mesh_params = tuple(denormalised[:, ids[0]:ids[1]] for ids in (SHAPE_IDS, EXP_IDS, POSE_IDS, TRANS_IDS))
    normal_map_img, _, _, _, rend_imgs = overlay_visualizer.get_rendered_mesh(flame_params=mesh_params,
                                                                              camera_params=camera_params)
    # NOTE(review): if rend_imgs is in [0, 255] (assumed — confirm), this maps
    # 255 to ~1.008 rather than exactly 1 (that would need /127.5). Kept as-is
    # so the generator input stays consistent with existing behaviour.
    rend_imgs = rend_imgs / 127.0 - 1

    pose = flame_params[:, constants.get_idx_list('GLOBAL_ROT')] if use_posed_constant_input else None

    if normal_maps_as_cond:
        gen_in = torch.cat((rend_imgs, normal_map_img), dim=1)
    elif rendered_flame_as_condition:
        gen_in = rend_imgs
    else:
        gen_in = flame_params

    generated = generate_from_flame_sequence(generator, gen_in, pose, step, alpha, noise,
                                             input_indices=fixed_embeddings)[-1]
    resized = fast_image_reshape(generated, height_out=rend_imgs.shape[2], width_out=rend_imgs.shape[3])
    fake_images = overlay_visualizer.range_normalize_images(resized)

    # Optionally place the conditioning images next to the generated ones.
    if rendered_flame_as_condition:
        fake_images = torch.cat([fake_images.cpu(), (rend_imgs.cpu() + 1) / 2], dim=-1)
    if normal_maps_as_cond:
        fake_images = torch.cat([fake_images.cpu(), (normal_map_img.cpu() + 1) / 2], dim=-1)
    return fake_images
class VisualizationSaver:
    """Generate fixed sets of samples from a model and save them as one image grid.

    Samples are produced in ``gen_i`` consecutive batches of ``gen_j`` items.
    Each batch is sliced in order from ``sampling_flame_labels``, ``pose`` and
    ``input_indices``.
    """

    def __init__(self, gen_i, gen_j, sampling_flame_labels, dataset, input_indices, overlay_mesh=False):
        """Store the sampling configuration.

        Args:
            gen_i: number of batches to generate; also used as ``nrow`` of the saved grid.
            gen_j: number of samples per batch.
            sampling_flame_labels: conditioning inputs, sliced into consecutive
                batches of ``gen_j``. Each slice must support ``.clone()``
                (presumably a torch tensor — confirm).
            dataset: stored on the instance; not used within this class.
            input_indices: per-sample embedding indices, sliced like
                ``sampling_flame_labels``.
            overlay_mesh: stored flag; not used within this class.
        """
        self.gen_i = gen_i
        self.gen_j = gen_j
        self.sampling_flame_labels = sampling_flame_labels
        self.overlay_mesh = overlay_mesh
        self.overlay_visualizer = OverLayViz()
        self.dataset = dataset
        self.input_indices = input_indices
        self.cam_t = np.array([0., 0., 2.5])
        # save_samples reads self.pose; without this default it raised
        # AttributeError unless set_flame_params had been called first.
        self.pose = None

    def set_flame_params(self, pose, sampling_flame_labels, input_indices):
        """Replace the pose (may be ``None``), conditioning inputs and embedding indices used for sampling."""
        self.pose = pose
        self.sampling_flame_labels = sampling_flame_labels
        self.input_indices = input_indices

    def save_samples(self, i, model, step, alpha, resolution, fid, run_id):
        """Generate ``gen_i * gen_j`` samples without gradients and write them to one PNG grid.

        Args:
            i: iteration counter; the file name contains ``i + 1`` zero-padded to 6 digits.
            model: called as ``model(flame, pose, step=..., alpha=..., input_indices=...)``;
                the last element of its output is taken as the image batch.
            step, alpha: forwarded to ``model``.
            resolution: only used in the output file name.
            fid: value embedded in the file name with two decimals.
            run_id: sub-directory under ``cnst.output_root + 'sample/'`` to write into.
        """
        images = []
        with torch.no_grad():
            for img_idx in range(self.gen_i):
                start = img_idx * self.gen_j
                stop = start + self.gen_j
                flame_batch = self.sampling_flame_labels[start:stop]
                pose_batch = None if self.pose is None else self.pose[start:stop]
                idx_batch = self.input_indices[start:stop]
                img_tensor = model(flame_batch.clone(), pose_batch, step=step, alpha=alpha,
                                   input_indices=idx_batch)[-1]

                # Resize to a fixed 256x256, then range-normalise via the overlay visualizer.
                resized = dataset_loaders.fast_image_reshape(img_tensor, height_out=256, width_out=256,
                                                             non_diff_allowed=True)
                img_tensor = self.overlay_visualizer.range_normalize_images(resized)
                images.append(img_tensor.data.cpu())

        out_path = (f'{cnst.output_root}sample/{str(run_id)}/{str(i + 1).zfill(6)}'
                    f'_res{resolution}x{resolution}_fid_{fid:.2f}.png')
        # NOTE(review): newer torchvision releases renamed ``range`` to
        # ``value_range`` in save_image/make_grid; confirm against the installed version.
        torchvision.utils.save_image(torch.cat(images, 0), out_path, nrow=self.gen_i,
                                     normalize=True, range=(0, 1))