def get_loss(self, gt_images, gt_mask, est_params, batch_size):
    # split params and unnormalize params
    _, est_lm, est_pp, est_shape, est_exp, est_color, est_illum, est_tex = split_300W_LP_labels(
        est_params)
    _, est_lm, est_pp, est_shape, est_exp, est_color, est_illum, est_tex = self.unnormalize_labels(
        self.train_batch_size, None, est_lm, est_pp, est_shape, est_exp,
        est_color, est_illum, est_tex)

    # geo loss, render with estimated geo parameters and ground truth pose
    est_images = render_batch(pose_param=est_pp,
                              shape_param=est_shape,
                              exp_param=est_exp,
                              tex_param=est_tex,
                              color_param=est_color,
                              illum_param=est_illum,
                              frame_width=self.resolution,
                              frame_height=self.resolution,
                              tf_bfm=self.bfm,
                              batch_size=batch_size)

    # compare only the pixels inside the ground-truth face mask
    gt_images = tf.where(gt_mask == 255, gt_images, 0)
    est_images = tf.where(gt_mask == 255, est_images, 0)

    # RMSE over the masked images, scaled for synchronous multi-replica training
    loss = tf.sqrt(tf.reduce_mean(tf.square(gt_images - est_images)))
    return loss / self.strategy.num_replicas_in_sync
def get_loss(self, gt_params, gt_images, est_params, batch_size):
    # gt_params contain only the 68 2D landmarks; scale them to pixel coordinates
    gt_lm = tf.reshape(gt_params, shape=(-1, 2, 68)) * self.resolution
    est_pp, est_shape, est_exp, est_color, est_illum, est_tex = split_ffhq_labels(
        est_params)

    # regularization loss: RMS of each (still normalized) parameter group
    loss_reg = tf.sqrt(tf.reduce_mean(tf.square(est_pp)))
    loss_reg += tf.sqrt(tf.reduce_mean(tf.square(est_shape)))
    loss_reg += tf.sqrt(tf.reduce_mean(tf.square(est_exp)))
    loss_reg += tf.sqrt(tf.reduce_mean(tf.square(est_color)))
    loss_reg += tf.sqrt(tf.reduce_mean(tf.square(est_illum)))
    loss_reg += tf.sqrt(tf.reduce_mean(tf.square(est_tex)))

    est_pp, est_shape, est_exp, est_color, est_illum, est_tex = self.unnormalize_labels(
        batch_size, est_pp, est_shape, est_exp, est_color, est_illum, est_tex)

    # the estimated pose only has x, y translation; add 0 for the t3d z axis
    est_pp = tf.concat([
        est_pp[:, :-1],
        tf.constant(0.0, shape=(batch_size, 1), dtype=tf.float32),
        est_pp[:, -1:]
    ], axis=1)

    # photometric loss: render with the estimated parameters and compare against
    # the ground-truth image on the pixels covered by the rendered face
    est_images = render_batch(pose_param=est_pp,
                              shape_param=est_shape,
                              exp_param=est_exp,
                              tex_param=est_tex,
                              color_param=est_color,
                              illum_param=est_illum,
                              frame_width=self.resolution,
                              frame_height=self.resolution,
                              tf_bfm=self.bfm,
                              batch_size=batch_size)
    gt_images = tf.cast(tf.where(est_images > 0, gt_images, 0), tf.float32)
    loss_img = tf.sqrt(tf.reduce_mean(tf.square(est_images - gt_images)))

    # landmark loss
    est_lm = self.bfm.get_landmarks(shape_param=est_shape,
                                    exp_param=est_exp,
                                    pose_param=est_pp,
                                    batch_size=batch_size,
                                    resolution=self.resolution,
                                    is_2d=True,
                                    is_plot=True)
    loss_lms = tf.sqrt(tf.reduce_mean(tf.square(gt_lm - est_lm)))

    return (loss_img / self.strategy.num_replicas_in_sync,
            50.0 * loss_lms / self.strategy.num_replicas_in_sync,
            loss_reg / self.strategy.num_replicas_in_sync)
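# A minimal sketch of how the three loss terms returned above could be combined
# inside a training step. The model/optimizer attributes, the regularization
# weight, and the name train_step_sketch are assumptions for illustration, not
# part of this repository; loss_lms already carries its 50.0 weight and each
# term is already divided by the number of replicas inside get_loss.
def train_step_sketch(self, gt_params, gt_images, batch_size):
    with tf.GradientTape() as tape:
        est_params = self.model(gt_images, training=True)
        loss_img, loss_lms, loss_reg = self.get_loss(
            gt_params, gt_images, est_params, batch_size)
        # hypothetical weighting of the regularizer
        total_loss = loss_img + loss_lms + 0.01 * loss_reg
    grads = tape.gradient(total_loss, self.model.trainable_variables)
    self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
    return total_loss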
def display(tfrecord_dir, bfm_path, param_mean_std_path, image_size,
            num_images=5, n_tex_para=40):
    print('Loading dataset %s' % tfrecord_dir)
    batch_size = 4
    dset = dataset.TFRecordDatasetSupervised(tfrecord_dir=tfrecord_dir,
                                             batch_size=batch_size,
                                             repeat=False,
                                             shuffle_mb=0)

    print('Loading BFM model')
    bfm = TfMorphableModel(model_path=bfm_path, n_tex_para=n_tex_para)

    idx = 0
    filename = '/opt/project/output/verify_dataset/supervised/20200525/image_batch_{0}_indx_{1}.jpg'
    unnormalize_labels = fn_unnormalize_300W_LP_labels(
        param_mean_std_path=param_mean_std_path, image_size=image_size)

    while idx < num_images:
        try:
            image_tensor, labels_tensor = dset.get_minibatch_tf()
        except tf.errors.OutOfRangeError:
            break

        # render images using the ground-truth labels
        roi, landmarks, pose_para, shape_para, exp_para, color_para, illum_para, tex_para = split_300W_LP_labels(
            labels_tensor)
        roi, landmarks, pose_para, shape_para, exp_para, color_para, illum_para, tex_para = unnormalize_labels(
            batch_size, roi, landmarks, pose_para, shape_para, exp_para,
            color_para, illum_para, tex_para)

        image_rendered = render_batch(pose_param=pose_para,
                                      shape_param=shape_para,
                                      exp_param=exp_para,
                                      tex_param=tex_para,
                                      color_param=color_para,
                                      illum_param=illum_para,
                                      frame_height=image_size,
                                      frame_width=image_size,
                                      tf_bfm=bfm,
                                      batch_size=batch_size).numpy().astype(np.uint8)

        for i in range(batch_size):
            # stack the input image on top of the rendered image
            images = np.concatenate(
                (image_tensor[i].numpy().astype(np.uint8), image_rendered[i]),
                axis=0)
            # images = image_rendered[i]
            imageio.imsave(filename.format(idx, i), images)
        idx += 1

    print('\nDisplayed %d images' % idx)
def example_render_batch3(pic_names: list, tf_bfm: TfMorphableModel,
                          n_tex_para: int, save_to_folder: str, resolution: int):
    batch_size = len(pic_names)
    images_orignal = load_images(pic_names, '/opt/project/examples/Data/80k/')
    shape_param_batch, exp_param_batch, pose_param_batch = load_params_80k(
        pic_names=pic_names)

    shape_param = tf.squeeze(shape_param_batch)
    exp_param = tf.squeeze(exp_param_batch)
    pose_param = tf.squeeze(pose_param_batch)

    # the 80k pose only has x, y translation; add 0 for the t3d z axis
    pose_param = tf.concat([
        pose_param[:, :-1],
        tf.constant(0.0, shape=(batch_size, 1), dtype=tf.float32),
        pose_param[:, -1:]
    ], axis=1)

    lm = tf_bfm.get_landmarks(shape_param,
                              exp_param,
                              pose_param,
                              batch_size,
                              450,
                              is_2d=True,
                              is_plot=True)

    images_rendered = render_batch(
        pose_param=pose_param,
        shape_param=shape_param,
        exp_param=exp_param,
        tex_param=tf.constant(0.0, shape=(len(pic_names), n_tex_para), dtype=tf.float32),
        color_param=None,
        illum_param=None,
        frame_height=450,
        frame_width=450,
        tf_bfm=tf_bfm,
        batch_size=batch_size).numpy().astype(np.uint8)

    for i, pic_name in enumerate(pic_names):
        fig = plt.figure()
        ax = fig.add_subplot(1, 2, 1)
        plot_image_w_lm(ax, resolution, images_orignal[i], lm[i])
        ax = fig.add_subplot(1, 2, 2)
        plot_image_w_lm(ax, resolution, images_rendered[i], lm[i])
        plt.savefig(os.path.join(save_to_folder, pic_name))
def example_render_batch2(pic_names: list, tf_bfm: TfMorphableModel,
                          save_to_folder: str, n_tex_para: int):
    batch_size = len(pic_names)
    images_orignal = load_images(pic_names, '/opt/project/examples/Data/300W_LP/')
    shape_param_batch, exp_param_batch, tex_param_batch, color_param_batch, illum_param_batch, pose_param_batch, lm_batch = \
        load_params(pic_names=pic_names, n_tex_para=n_tex_para)

    # pose_param: [batch, n_pose_param]
    # shape_param: [batch, n_shape_para]
    # exp_param: [batch, n_exp_para]
    # tex_param: [batch, n_tex_para]
    # color_param: [batch, n_color_para]
    # illum_param: [batch, n_illum_para]
    shape_param_batch = tf.squeeze(shape_param_batch)
    exp_param_batch = tf.squeeze(exp_param_batch)
    tex_param_batch = tf.squeeze(tex_param_batch)
    color_param_batch = tf.squeeze(color_param_batch)
    illum_param_batch = tf.squeeze(illum_param_batch)
    pose_param_batch = tf.squeeze(pose_param_batch)

    lm_rended = tf_bfm.get_landmarks(shape_param_batch,
                                     exp_param_batch,
                                     pose_param_batch,
                                     batch_size,
                                     450,
                                     is_2d=True,
                                     is_plot=True)

    images_rendered = render_batch(
        pose_param=pose_param_batch,
        shape_param=shape_param_batch,
        exp_param=exp_param_batch,
        tex_param=tex_param_batch,
        color_param=color_param_batch,
        illum_param=illum_param_batch,
        frame_height=450,
        frame_width=450,
        tf_bfm=tf_bfm,
        batch_size=batch_size).numpy().astype(np.uint8)

    for i, pic_name in enumerate(pic_names):
        fig = plt.figure()
        ax = fig.add_subplot(1, 2, 1)
        plot_image_w_lm(ax, 450, images_orignal[i], lm_batch[i])
        ax = fig.add_subplot(1, 2, 2)
        plot_image_w_lm(ax, 450, images_rendered[i], lm_rended[i])
        plt.savefig(os.path.join(save_to_folder, pic_name))
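# The plot_image_w_lm helper used in example_render_batch2/3 above is not shown
# in this section. Below is a minimal sketch of what such a helper might look
# like, assuming it draws an image on a matplotlib axis and overlays the 68 2D
# landmarks (shape [2, 68]) as white dots; the name plot_image_w_lm_sketch and
# the drawing style are assumptions, not the repository's actual implementation.
def plot_image_w_lm_sketch(ax, resolution, image, lm):
    # image: [H, W, 3] uint8; lm: [2, 68] pixel coordinates
    ax.imshow(image)
    ax.scatter(lm[0, :], lm[1, :], s=4, c='w')
    ax.set_xlim(0, resolution)
    ax.set_ylim(resolution, 0)  # image coordinates: y grows downwards
    ax.axis('off')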
def inference_and_render_images(images, images_names, model, bfm,
                                unnormalize_labels, rendered_filename_tmp):
    batch_size = len(images)
    reals = tf.convert_to_tensor(images, dtype=tf.uint8)
    reals = process_reals_supervised(x=reals,
                                     mirror_augment=False,
                                     drange_data=[0, 255],
                                     drange_net=[-1, 1])
    est_params = model(reals)

    pose_para, shape_para, exp_para, color_para, illum_para, tex_para = split_ffhq_labels(est_params)
    pose_para, shape_para, exp_para, color_para, illum_para, tex_para = unnormalize_labels(
        batch_size, pose_para, shape_para, exp_para, color_para, illum_para, tex_para)

    # the 80k dataset only has x, y translation; add 0 for the t3d z axis
    pose_para = tf.concat([
        pose_para[:, :-1],
        tf.constant(0.0, shape=(batch_size, 1), dtype=tf.float32),
        pose_para[:, -1:]
    ], axis=1)

    landmarks = bfm.get_landmarks(shape_para, exp_para, pose_para, batch_size,
                                  image_size, is_2d=True, is_plot=True)

    image_rendered = render_batch(
        pose_param=pose_para,
        shape_param=shape_para,
        exp_param=exp_para,
        tex_param=tex_para,
        color_param=color_para,
        illum_param=illum_para,
        frame_height=image_size,
        frame_width=image_size,
        tf_bfm=bfm,
        batch_size=batch_size).numpy().astype(np.uint8)

    for i in range(batch_size):
        # input image with estimated landmarks
        pic_name = '.'.join(images_names[i].split('.')[:-1])
        img_rgb = cv2.cvtColor(images[i], cv2.COLOR_BGR2RGB)
        img_rgb = add_landmarks(img_rgb, landmarks[i])

        # rendered image with the same landmarks
        img_rendered_rgb = cv2.cvtColor(image_rendered[i], cv2.COLOR_BGR2RGB)
        img_rendered_rgb = add_landmarks(img_rendered_rgb, landmarks[i])

        # save input and rendering side by side
        img_all = np.concatenate((img_rgb, img_rendered_rgb), axis=1)
        cv2.imwrite(rendered_filename_tmp.format(pic_name), img_all)
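# The add_landmarks helper used above is not defined in this section. A minimal
# sketch of what it might do, assuming it draws each of the 68 2D landmarks
# (shape [2, 68]) as a small filled circle with OpenCV; the name
# add_landmarks_sketch and the drawing style are assumptions, not the
# repository's actual implementation.
def add_landmarks_sketch(image, landmarks, color=(0, 255, 0)):
    # image: [H, W, 3] uint8; landmarks: [2, 68] pixel coordinates
    for x, y in zip(landmarks[0, :], landmarks[1, :]):
        cv2.circle(image, (int(x), int(y)), 2, color, -1)
    return image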
        color_param_batch.append(
            tf.constant(mat_data['Color_Para'], dtype=tf.float32))
        illum_param_batch.append(
            tf.constant(mat_data['Illum_Para'], dtype=tf.float32))
        pose_param_batch.append(
            tf.constant(mat_data['Pose_Para'], dtype=tf.float32))

    shape_param_batch = tf.stack(shape_param_batch, axis=0)
    exp_param_batch = tf.stack(exp_param_batch, axis=0)
    tex_param_batch = tf.stack(tex_param_batch, axis=0)
    color_param_batch = tf.stack(color_param_batch, axis=0)
    illum_param_batch = tf.stack(illum_param_batch, axis=0)
    pose_param_batch = tf.stack(pose_param_batch, axis=0)
    return shape_param_batch, exp_param_batch, tex_param_batch, color_param_batch, illum_param_batch, pose_param_batch


shape_param_batch, exp_param_batch, tex_param_batch, color_param_batch, illum_param_batch, pose_param_batch = \
    my_load_params(pic_names=pic_names, n_tex_para=n_tex_para)

# render the same batch repeatedly, printing the iteration count
i = 0
while True:
    images = render_batch(pose_param=pose_param_batch,
                          shape_param=shape_param_batch,
                          exp_param=exp_param_batch,
                          tex_param=tex_param_batch,
                          color_param=color_param_batch,
                          illum_param=illum_param_batch,
                          frame_height=450,
                          frame_width=450,
                          tf_bfm=tf_bfm,
                          batch_size=batch_size)
    i += 1
    print(i)
def display(tfrecord_dir, bfm_path, exp_path, param_mean_std_path, image_size,
            num_images=5, n_tex_para=40, n_shape_para=100):
    print('Loading dataset %s' % tfrecord_dir)
    batch_size = 4
    dset = dataset.TFRecordDatasetSupervised(tfrecord_dir=tfrecord_dir,
                                             batch_size=batch_size,
                                             repeat=False,
                                             shuffle_mb=0)

    print('Loading BFM model')
    bfm = TfMorphableModel(model_path=bfm_path,
                           exp_path=exp_path,
                           n_shape_para=n_shape_para,
                           n_tex_para=n_tex_para)

    idx = 0
    filename = '/opt/project/output/verify_dataset/supervised-80k/20200717/image_batch_{0}_indx_{1}.jpg'
    unnormalize_labels = fn_unnormalize_80k_labels(
        param_mean_std_path=param_mean_std_path, image_size=image_size)

    # 68-point landmark segments: jaw contour, brows, nose, eyes, mouth
    lm_segments = [(0, 17), (17, 22), (22, 27), (27, 31), (31, 36),
                   (36, 42), (42, 48), (48, 60), (60, 68)]

    while idx < num_images:
        try:
            image_tensor, labels_tensor = dset.get_minibatch_tf()
        except tf.errors.OutOfRangeError:
            break

        # render images using labels
        pose_para, shape_para, exp_para, _, _, _ = split_80k_labels(labels_tensor)
        pose_para, shape_para, exp_para, _, _, _ = unnormalize_labels(
            batch_size, pose_para, shape_para, exp_para, None, None, None)

        # the 80k dataset only has x, y translation; add 0 for the t3d z axis
        pose_para = tf.concat([
            pose_para[:, :-1],
            tf.constant(0.0, shape=(batch_size, 1), dtype=tf.float32),
            pose_para[:, -1:]
        ], axis=1)

        landmarks = bfm.get_landmarks(shape_para, exp_para, pose_para,
                                      batch_size, image_size,
                                      is_2d=True, is_plot=True)

        image_rendered = render_batch(
            pose_param=pose_para,
            shape_param=shape_para,
            exp_param=exp_para,
            tex_param=tf.constant(0.0, shape=(batch_size, n_tex_para), dtype=tf.float32),
            color_param=None,
            illum_param=None,
            frame_height=image_size,
            frame_width=image_size,
            tf_bfm=bfm,
            batch_size=batch_size).numpy().astype(np.uint8)

        for i in range(batch_size):
            fig = plt.figure()

            # input image with ground-truth landmarks
            ax = fig.add_subplot(1, 2, 1)
            ax.imshow(image_tensor[i].numpy().astype(np.uint8))
            for start, end in lm_segments:
                ax.plot(landmarks[i, 0, start:end], landmarks[i, 1, start:end],
                        marker='o', markersize=2, linestyle='-', color='w', lw=2)

            # rendered image with the same landmarks
            ax2 = fig.add_subplot(1, 2, 2)
            ax2.imshow(image_rendered[i])
            for start, end in lm_segments:
                ax2.plot(landmarks[i, 0, start:end], landmarks[i, 1, start:end],
                         marker='o', markersize=2, linestyle='-', color='w', lw=2)

            plt.savefig(filename.format(idx, i))
        idx += 1
def display(tfrecord_dir, bfm_path, param_mean_std_path, image_size,
            num_images=5, n_tex_para=40):
    print('Loading dataset %s' % tfrecord_dir)
    batch_size = 4
    dset = dataset.TFRecordDatasetSupervised(tfrecord_dir=tfrecord_dir,
                                             batch_size=batch_size,
                                             repeat=False,
                                             shuffle_mb=0)

    print('Loading BFM model')
    bfm = TfMorphableModel(model_path=bfm_path, n_tex_para=n_tex_para)

    idx = 0
    filename = '/opt/project/output/verify_dataset/supervised/20200525/image_batch_{0}_indx_{1}.jpg'
    unnormalize_labels = fn_unnormalize_300W_LP_labels(
        param_mean_std_path=param_mean_std_path, image_size=image_size)

    # 68-point landmark segments: jaw contour, brows, nose, eyes, mouth
    lm_segments = [(0, 17), (17, 22), (22, 27), (27, 31), (31, 36),
                   (36, 42), (42, 48), (48, 60), (60, 68)]

    while idx < num_images:
        try:
            image_tensor, labels_tensor = dset.get_minibatch_tf()
        except tf.errors.OutOfRangeError:
            break

        # render images using the ground-truth labels
        roi, landmarks, pose_para, shape_para, exp_para, color_para, illum_para, tex_para = split_300W_LP_labels(
            labels_tensor)
        roi, landmarks, pose_para, shape_para, exp_para, color_para, illum_para, tex_para = unnormalize_labels(
            batch_size, roi, landmarks, pose_para, shape_para, exp_para,
            color_para, illum_para, tex_para)

        image_rendered = render_batch(pose_param=pose_para,
                                      shape_param=shape_para,
                                      exp_param=exp_para,
                                      tex_param=tex_para,
                                      color_param=color_para,
                                      illum_param=illum_para,
                                      frame_height=image_size,
                                      frame_width=image_size,
                                      tf_bfm=bfm,
                                      batch_size=batch_size).numpy().astype(np.uint8)

        for i in range(batch_size):
            fig = plt.figure()

            # input image with ground-truth landmarks
            ax = fig.add_subplot(1, 2, 1)
            ax.imshow(image_tensor[i].numpy().astype(np.uint8))
            for start, end in lm_segments:
                ax.plot(landmarks[i, 0, start:end], landmarks[i, 1, start:end],
                        marker='o', markersize=2, linestyle='-', color='w', lw=2)

            # rendered image with the same landmarks
            ax2 = fig.add_subplot(1, 2, 2)
            ax2.imshow(image_rendered[i])
            for start, end in lm_segments:
                ax2.plot(landmarks[i, 0, start:end], landmarks[i, 1, start:end],
                         marker='o', markersize=2, linestyle='-', color='w', lw=2)

            plt.savefig(filename.format(idx, i))
        idx += 1

    print('\nDisplayed %d images' % idx)