Example #1
0
def example_render_batch3(pic_names: list, tf_bfm: TfMorphableModel,
                          n_tex_para: int, save_to_folder: str,
                          resolution: int):
    """Render untextured 80k-dataset fits and save side-by-side figures.

    For every name in ``pic_names`` a two-panel figure is written to
    ``save_to_folder``:

    - Left panel: the original image from ``/opt/project/examples/Data/80k/``.
    - Right panel: the mesh rendered with a zero texture vector.

    Both panels are overlaid with the same projected landmarks.

    Args:
        pic_names: image names; also used as output file names.
        tf_bfm: morphable model used for landmark projection and rendering.
        n_tex_para: length of the (all-zero) texture parameter vector.
        save_to_folder: directory the figures are saved into.
        resolution: forwarded to ``plot_image_w_lm`` only.
    """
    batch_size = len(pic_names)
    # Landmark projection and rendering both use a fixed 450x450 frame;
    # ``resolution`` is only passed to the plotting helper.
    # NOTE(review): confirm whether these are meant to agree.
    frame_size = 450

    images_original = load_images(pic_names, '/opt/project/examples/Data/80k/')

    shape_param_batch, exp_param_batch, pose_param_batch = load_params_80k(
        pic_names=pic_names)
    # Flatten every parameter set to (batch, n). A bare tf.squeeze() would
    # also drop the batch axis when a single picture is given, and the
    # ``pose_param[:, ...]`` slicing below would then fail.
    # Assumes each picture carries one flat parameter vector -- TODO confirm.
    shape_param = tf.reshape(shape_param_batch, (batch_size, -1))
    exp_param = tf.reshape(exp_param_batch, (batch_size, -1))
    pose_param = tf.reshape(pose_param_batch, (batch_size, -1))
    # Insert a zero column before the last pose entry. This is presumably the
    # z translation, which the 80k labels do not provide.
    pose_param = tf.concat([
        pose_param[:, :-1],
        tf.constant(0.0, shape=(batch_size, 1), dtype=tf.float32),
        pose_param[:, -1:]
    ],
                           axis=1)
    lm = tf_bfm.get_landmarks(shape_param,
                              exp_param,
                              pose_param,
                              batch_size,
                              frame_size,
                              is_2d=True,
                              is_plot=True)

    images_rendered = render_batch(
        pose_param=pose_param,
        shape_param=shape_param,
        exp_param=exp_param,
        tex_param=tf.constant(0.0,
                              shape=(batch_size, n_tex_para),
                              dtype=tf.float32),
        color_param=None,
        illum_param=None,
        frame_height=frame_size,
        frame_width=frame_size,
        tf_bfm=tf_bfm,
        batch_size=batch_size).numpy().astype(np.uint8)

    for i, pic_name in enumerate(pic_names):
        fig = plt.figure()
        ax = fig.add_subplot(1, 2, 1)
        plot_image_w_lm(ax, resolution, images_original[i], lm[i])
        ax = fig.add_subplot(1, 2, 2)
        plot_image_w_lm(ax, resolution, images_rendered[i], lm[i])
        fig.savefig(os.path.join(save_to_folder, pic_name))
        # Release the figure; otherwise every iteration leaks one figure.
        plt.close(fig)
Example #2
0
def example_render_batch2(pic_names: list, tf_bfm: TfMorphableModel, save_to_folder: str, n_tex_para: int):
    """Render fully parameterized 300W-LP fits and save side-by-side figures.

    For every name in ``pic_names`` a two-panel figure is written to
    ``save_to_folder``:

    - Left panel: the original image with the landmarks returned by
      ``load_params``.
    - Right panel: the rendered face with the landmarks projected by
      ``tf_bfm.get_landmarks``.

    Args:
        pic_names: image names; also used as output file names.
        tf_bfm: morphable model used for landmark projection and rendering.
        save_to_folder: directory the figures are saved into.
        n_tex_para: number of texture parameters to load.
    """
    batch_size = len(pic_names)
    frame_size = 450  # fixed render / landmark-projection frame

    images_original = load_images(pic_names, '/opt/project/examples/Data/300W_LP/')

    (shape_param_batch, exp_param_batch, tex_param_batch, color_param_batch,
     illum_param_batch, pose_param_batch, lm_batch) = \
        load_params(pic_names=pic_names, n_tex_para=n_tex_para)

    # Target shapes after flattening:
    # pose_param:  [batch, n_pose_param]
    # shape_param: [batch, n_shape_para]
    # exp_param:   [batch, n_exp_para]
    # tex_param:   [batch, n_tex_para]
    # color_param: [batch, n_color_para]
    # illum_param: [batch, n_illum_para]
    # tf.reshape is used instead of a bare tf.squeeze so the batch axis
    # survives when only one picture is given.
    shape_param_batch = tf.reshape(shape_param_batch, (batch_size, -1))
    exp_param_batch = tf.reshape(exp_param_batch, (batch_size, -1))
    tex_param_batch = tf.reshape(tex_param_batch, (batch_size, -1))
    color_param_batch = tf.reshape(color_param_batch, (batch_size, -1))
    illum_param_batch = tf.reshape(illum_param_batch, (batch_size, -1))
    pose_param_batch = tf.reshape(pose_param_batch, (batch_size, -1))
    lm_rendered = tf_bfm.get_landmarks(shape_param_batch, exp_param_batch, pose_param_batch,
                                       batch_size, frame_size, is_2d=True, is_plot=True)

    images_rendered = render_batch(
        pose_param=pose_param_batch,
        shape_param=shape_param_batch,
        exp_param=exp_param_batch,
        tex_param=tex_param_batch,
        color_param=color_param_batch,
        illum_param=illum_param_batch,
        frame_height=frame_size,
        frame_width=frame_size,
        tf_bfm=tf_bfm,
        batch_size=batch_size
    ).numpy().astype(np.uint8)

    for i, pic_name in enumerate(pic_names):
        fig = plt.figure()
        ax = fig.add_subplot(1, 2, 1)
        plot_image_w_lm(ax, frame_size, images_original[i], lm_batch[i])
        ax = fig.add_subplot(1, 2, 2)
        plot_image_w_lm(ax, frame_size, images_rendered[i], lm_rendered[i])
        fig.savefig(os.path.join(save_to_folder, pic_name))
        # Release the figure; otherwise every iteration leaks one figure.
        plt.close(fig)
Example #3
0
from example_utils import load_images, load_params
from sandbox.landmarks.check_get_landmarks2 import save_landmarks
from tf_3dmm.morphable_model.morphable_model import TfMorphableModel


if __name__ == '__main__':
    # Project BFM landmarks for a few sample faces and save the overlays.
    n_tex_para = 40
    resolution = 450
    bfm = TfMorphableModel(
        model_path='/opt/project/examples/Data/BFM/Out/BFM.mat',
        n_tex_para=n_tex_para,
    )
    output_folder = '/opt/project/output/landmarks/landmark3'
    # For a single-face run use: pic_names = ['IBUG_image_008_1_0']
    pic_names = [
        'image00002',
        'IBUG_image_014_01_2',
        'AFW_134212_1_0',
        'IBUG_image_008_1_0',
    ]
    batch_size = len(pic_names)
    images = load_images(pic_names, '/opt/project/examples/Data')
    # Only shape, expression and pose parameters are needed for landmarks.
    shape_param, exp_param, _, _, _, pose_param = load_params(
        pic_names=pic_names,
        n_tex_para=n_tex_para,
        data_folder='/opt/project/examples/Data/',
    )
    landmarks = bfm.get_landmarks(shape_param=shape_param, exp_param=exp_param,
                                  pose_param=pose_param, batch_size=batch_size,
                                  resolution=resolution, is_2d=True, is_plot=True)
    save_landmarks(images=images, landmarks=landmarks,
                   output_folder=output_folder, resolution=resolution)
# (start, end) index ranges of the landmark vector drawn as separate
# polylines. Presumably the 68-point iBUG layout (jaw, brows, nose, eyes,
# lips) -- confirm against the model's landmark indices.
_DISPLAY_LM_SEGMENTS = ((0, 17), (17, 22), (22, 27), (27, 31), (31, 36),
                        (36, 42), (42, 48), (48, 60), (60, 68))


def _plot_lm_segments(ax, lm):
    """Draw each landmark segment of one face on ``ax`` as a white polyline.

    ``lm`` is indexed as ``lm[0, a:b]`` for x and ``lm[1, a:b]`` for y.
    """
    for start, end in _DISPLAY_LM_SEGMENTS:
        ax.plot(lm[0, start:end],
                lm[1, start:end],
                marker='o',
                markersize=2,
                linestyle='-',
                color='w',
                lw=2)


def display(tfrecord_dir,
            bfm_path,
            exp_path,
            param_mean_std_path,
            image_size,
            num_images=5,
            n_tex_para=40,
            n_shape_para=100,
            filename_pattern='/opt/project/output/verify_dataset/supervised-80k/20200717/image_batch_{0}_indx_{1}.jpg'):
    """Visually verify a supervised 80k TFRecord dataset.

    For each sample, one figure is saved with two panels:

    - Left panel: the input image.
    - Right panel: a render produced from its un-normalized labels with a
      zero texture.

    Both panels are overlaid with the landmarks projected from the labels.

    Args:
        tfrecord_dir: directory of the TFRecord dataset.
        bfm_path: path to the BFM model file.
        exp_path: path to the expression basis.
        param_mean_std_path: statistics used to un-normalize the labels.
        image_size: render/landmark frame size (square).
        num_images: number of *minibatches* to process (``idx`` counts batches).
        n_tex_para: texture parameter count (texture is rendered as zeros).
        n_shape_para: shape parameter count for the model.
        filename_pattern: output path format; ``{0}`` is the batch index and
            ``{1}`` is the sample index within the batch.
    """
    print('Loading dataset %s' % tfrecord_dir)

    batch_size = 4
    dset = dataset.TFRecordDatasetSupervised(tfrecord_dir=tfrecord_dir,
                                             batch_size=batch_size,
                                             repeat=False,
                                             shuffle_mb=0)
    print('Loading BFM model')
    bfm = TfMorphableModel(model_path=bfm_path,
                           exp_path=exp_path,
                           n_shape_para=n_shape_para,
                           n_tex_para=n_tex_para)

    unnormalize_labels = fn_unnormalize_80k_labels(
        param_mean_std_path=param_mean_std_path, image_size=image_size)

    idx = 0
    while idx < num_images:
        try:
            image_tensor, labels_tensor = dset.get_minibatch_tf()
        except tf.errors.OutOfRangeError:
            break

        # A non-repeating dataset's final minibatch may hold fewer than
        # ``batch_size`` samples, so size everything from what was read.
        cur_batch = int(image_tensor.shape[0])

        # Render images using the labels.
        pose_para, shape_para, exp_para, _, _, _ = split_80k_labels(
            labels_tensor)
        pose_para, shape_para, exp_para, _, _, _ = unnormalize_labels(
            cur_batch, pose_para, shape_para, exp_para, None, None, None)
        # Add 0 for the t3d z axis: the 80k dataset only has x, y translation.
        pose_para = tf.concat([
            pose_para[:, :-1],
            tf.constant(0.0, shape=(cur_batch, 1), dtype=tf.float32),
            pose_para[:, -1:]
        ],
                              axis=1)

        landmarks = bfm.get_landmarks(shape_para,
                                      exp_para,
                                      pose_para,
                                      cur_batch,
                                      image_size,
                                      is_2d=True,
                                      is_plot=True)
        image_rendered = render_batch(
            pose_param=pose_para,
            shape_param=shape_para,
            exp_param=exp_para,
            tex_param=tf.constant(0.0,
                                  shape=(cur_batch, n_tex_para),
                                  dtype=tf.float32),
            color_param=None,
            illum_param=None,
            frame_height=image_size,
            frame_width=image_size,
            tf_bfm=bfm,
            batch_size=cur_batch).numpy().astype(np.uint8)

        for i in range(cur_batch):
            fig = plt.figure()
            # Input image.
            ax = fig.add_subplot(1, 2, 1)
            ax.imshow(image_tensor[i].numpy().astype(np.uint8))
            _plot_lm_segments(ax, landmarks[i])
            # Rendered image.
            ax2 = fig.add_subplot(1, 2, 2)
            ax2.imshow(image_rendered[i])
            _plot_lm_segments(ax2, landmarks[i])

            fig.savefig(filename_pattern.format(idx, i))
            # Release the figure; otherwise every sample leaks one figure.
            plt.close(fig)

        idx += 1