# Example No. 1
# 0
def _make_view_inputs(view, num_vertices, uv_size):
    """Create the input placeholders for one camera view.

    Args:
        view: placeholder name prefix, e.g. "front", "left" or "right".
        num_vertices: number of mesh vertices for the proj_xyz / ver_norm inputs.
        uv_size: side length the colour image is resized to.

    Returns:
        Tuple (image, image_resized, seg, proj_xyz, ver_norm). The four
        placeholders are named "<view>_image", "<view>_seg",
        "<view>_proj_xyz" and "<view>_ver_norm"; image_resized is the image
        resized to (uv_size, uv_size).
    """
    image = tf.compat.v1.placeholder(dtype=tf.float32,
                                     shape=[1, None, None, 3],
                                     name=view + "_image")
    image_resized = tf.image.resize(image, (uv_size, uv_size))
    seg = tf.compat.v1.placeholder(dtype=tf.float32,
                                   shape=[1, None, None, 19],
                                   name=view + "_seg")
    proj_xyz = tf.compat.v1.placeholder(dtype=tf.float32,
                                        shape=[1, num_vertices, 3],
                                        name=view + "_proj_xyz")
    ver_norm = tf.compat.v1.placeholder(dtype=tf.float32,
                                        shape=[1, num_vertices, 3],
                                        name=view + "_ver_norm")
    return image, image_resized, seg, proj_xyz, ver_norm


def _load_unit_image(path, uv_size):
    """Load the image at `path`, resize it to (uv_size, uv_size) and divide by 255.

    Returns a float32 numpy array; the underlying file is closed before returning.
    """
    with Image.open(path) as img:
        return np.asarray(img.resize((uv_size, uv_size)), np.float32) / 255


def _unwrap_view(image_resized, seg, proj_xyz, ver_norm, basis3dmm, uv_size):
    """Unwrap one view's colour image and its segmentation into UV space.

    Returns:
        Tuple (uv, uv_mask, uv_seg) produced by unwrap_utils.unwrap_img_into_uv.
    """
    # The colour image was resized to uv_size, so its projected coordinates are
    # rescaled by uv_size / 300; the segmentation is unwrapped at its own
    # resolution with unscaled coordinates (the data notes in main describe
    # 300x300 inputs).
    uv, uv_mask = unwrap_utils.unwrap_img_into_uv(
        image_resized / 255.0,
        proj_xyz * uv_size / 300,
        ver_norm,
        basis3dmm,
        uv_size,
    )
    uv_seg, _ = unwrap_utils.unwrap_img_into_uv(
        seg,
        proj_xyz,
        ver_norm,
        basis3dmm,
        uv_size,
    )
    return uv, uv_mask, uv_seg


def main(_):
    """Unwrap the views stored in `*texture.mat` files into UV textures.

    For every `*texture.mat` file in FLAGS.input_dir, the front view (and,
    when FLAGS.is_mult_view is set, the left and right views at indices 1 and
    2) is unwrapped into UV space, blended over ../resources/base_tex.png, and
    written to FLAGS.output_dir as `<prefix>_tex.png` and `<prefix>_mask.png`.
    If FLAGS.write_graph is set, the graph is written as text to FLAGS.pb_path
    and the process exits instead.

    Raises:
        ValueError: if a .mat file holds fewer views than required.
    """
    import sys

    # load 3dmm
    basis3dmm = load_3dmm_basis(
        FLAGS.basis3dmm_path,
        FLAGS.uv_path,
        is_whole_uv=True,
    )

    os.makedirs(FLAGS.output_dir, exist_ok=True)

    uv_size = FLAGS.uv_size
    num_vertices = basis3dmm["basis_shape"].shape[1] // 3

    # ---- build graph ----
    # Each entry is (image, image_resized, seg, proj_xyz, ver_norm); entry i
    # is fed from index i of the arrays stored in each .mat file.
    view_inputs = [_make_view_inputs("front", num_vertices, uv_size)]

    base_uv = _load_unit_image("../resources/base_tex.png", uv_size)
    base_uv_batch = tf.constant(base_uv[np.newaxis, ...], name="base_uv")

    if FLAGS.is_mult_view:
        view_inputs.append(_make_view_inputs("left", num_vertices, uv_size))
        view_inputs.append(_make_view_inputs("right", num_vertices, uv_size))

        # read fixed blending masks for multiview
        front_mask = _load_unit_image("../resources/mid_blend_mask.png", uv_size)
        left_mask = _load_unit_image("../resources/left_blend_mask.png", uv_size)
        right_mask = _load_unit_image("../resources/right_blend_mask.png", uv_size)
        mask_front_batch = tf.constant(front_mask[np.newaxis, ..., np.newaxis],
                                       tf.float32,
                                       name="mask_front")
        mask_left_batch = tf.constant(left_mask[np.newaxis, ..., np.newaxis],
                                      tf.float32,
                                      name="mask_left")
        mask_right_batch = tf.constant(right_mask[np.newaxis, ..., np.newaxis],
                                       tf.float32,
                                       name="mask_right")

    # Unwrap every view (front first) into (uv, uv_mask, uv_seg).
    unwrapped = [
        _unwrap_view(image_resized, seg, proj_xyz, ver_norm, basis3dmm, uv_size)
        for _, image_resized, seg, proj_xyz, ver_norm in view_inputs
    ]
    front_uv_batch, front_uv_mask_batch, front_uv_seg_batch = unwrapped[0]

    if FLAGS.is_mult_view:
        left_uv_batch, _, left_uv_seg_batch = unwrapped[1]
        right_uv_batch, _, right_uv_seg_batch = unwrapped[2]

        # blend multiview: segmentation masks first, then colours
        left_uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(left_uv_seg_batch)
        right_uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(right_uv_seg_batch)
        front_uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(front_uv_seg_batch)

        cur_seg = tf_blend_uv(left_uv_seg_mask_batch,
                              right_uv_seg_mask_batch,
                              mask_right_batch,
                              match_color=False)
        uv_seg_mask_batch = tf_blend_uv(cur_seg,
                                        front_uv_seg_mask_batch,
                                        mask_front_batch,
                                        match_color=False)

        mask_batch = tf.clip_by_value(
            mask_front_batch + mask_left_batch + mask_right_batch, 0, 1)
        uv_mask_batch = mask_batch * uv_seg_mask_batch
        cur_uv = tf_blend_uv(left_uv_batch,
                             right_uv_batch,
                             mask_right_batch,
                             match_color=False)
        cur_uv = tf_blend_uv(cur_uv,
                             front_uv_batch,
                             mask_front_batch,
                             match_color=False)
        uv_batch = tf_blend_uv(base_uv_batch,
                               cur_uv,
                               uv_mask_batch,
                               match_color=True)
    else:
        uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(front_uv_seg_batch)
        uv_mask_batch = front_uv_mask_batch * uv_seg_mask_batch
        uv_batch = tf_blend_uv(base_uv_batch,
                               front_uv_batch,
                               uv_mask_batch,
                               match_color=True)

    uv_batch = tf.identity(uv_batch, name="uv_tex")
    uv_seg_mask_batch = tf.identity(uv_seg_mask_batch, name="uv_seg")
    uv_mask_batch = tf.identity(uv_mask_batch, name="uv_mask")

    # Never run below; kept so the graph dumped with --write_graph is unchanged.
    tf.compat.v1.global_variables_initializer()

    # One session for all files; the `with` closes it once, after the loop
    # (it was previously closed inside the loop, breaking every file after
    # the first).
    with tf.compat.v1.Session() as sess:
        if FLAGS.write_graph:
            tf.io.write_graph(sess.graph_def, "", FLAGS.pb_path, as_text=True)
            sys.exit()

        # ---- load data ----
        # seg: [300,300,19], segmentation
        # diffuse: [300,300,3], diffuse images
        # proj_xyz: [N,3]
        # ver_norm: [N,3]
        info_paths = glob.glob(os.path.join(FLAGS.input_dir, "*texture.mat"))
        image_key = "ori_img" if FLAGS.is_orig_img else "diffuse"
        num_views = len(view_inputs)  # multi-view order: front, left, right

        for info_path in info_paths:
            info = scipy.io.loadmat(info_path)
            if info["proj_xyz"].shape[0] < num_views:
                raise ValueError("%s: expected at least %d views, got %d" %
                                 (info_path, num_views,
                                  info["proj_xyz"].shape[0]))

            feed_dict = {}
            for idx, (image, _, seg, proj_xyz, ver_norm) in enumerate(view_inputs):
                feed_dict[image] = info[image_key][idx][np.newaxis, ...]
                feed_dict[proj_xyz] = info["proj_xyz"][idx:idx + 1, ...]
                feed_dict[ver_norm] = info["ver_norm"][idx:idx + 1, ...]
                feed_dict[seg] = info["seg"][idx:idx + 1, ...]

            uv_tex_res, uv_mask_res = sess.run([uv_batch, uv_mask_batch],
                                               feed_dict)

            prefix = os.path.basename(info_path).split(".")[0]
            # Results are expected in [0, 1]; clip after scaling so any
            # out-of-range values saturate instead of wrapping in the uint8 cast.
            uv_tex_res = np.clip(uv_tex_res[0] * 255, 0, 255)
            uv_mask_res = np.clip(uv_mask_res[0] * 255, 0, 255)
            Image.fromarray(uv_tex_res.astype(np.uint8)).save(
                os.path.join(FLAGS.output_dir, prefix + "_tex.png"))
            Image.fromarray(np.squeeze(uv_mask_res).astype(np.uint8)).save(
                os.path.join(FLAGS.output_dir, prefix + "_mask.png"))


def main(_):
    """Step 4: merge three views into one UV texture, then fit texture parameters.

    Loads the three view images, projected vertices, preset blending masks and
    base UV texture with load_from_npz(FLAGS.output_dir, ...). Each view is
    unwrapped into a 512x512 UV map, and the maps are blended over the base
    texture. `para_tex` (coefficients of basis3dmm['uv']) is then fitted to a
    blurred copy of the merge with Adam. Writes several PNGs plus
    `para_tex_init.npy` to FLAGS.output_dir. If FLAGS.write_graph is set, the
    graph is written as text to FLAGS.pb_path and the process exits instead.

    Raises:
        ValueError: if the three view images differ in height/width, or if no
            fitting loss has a positive weight.
    """
    import sys

    print('---- step4 start -----')
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.GPU_NO
    print('running base:', FLAGS.output_dir)

    uv_size = 512  # resolution of the unwrapped, merged and fitted UV maps

    basis3dmm = load_3dmm_basis(FLAGS.basis3dmm_path,
                                FLAGS.uv_path,
                                is_whole_uv=True)

    # load data (all : 0-255-float32)
    img_list, _, preset_masks_list, base_uv, pro_xyz_list = load_from_npz(
        FLAGS.output_dir, basis3dmm)

    # set tf datas
    base_uv_batch = tf.constant(base_uv[np.newaxis, ...], name='base_uv')
    mask_mid_batch = tf.constant(
        preset_masks_list[0][np.newaxis, ..., np.newaxis], tf.float32)
    mask_left_batch = tf.constant(
        preset_masks_list[1][np.newaxis, ..., np.newaxis], tf.float32)
    mask_right_batch = tf.constant(
        preset_masks_list[2][np.newaxis, ..., np.newaxis], tf.float32)
    mask_batch = tf.clip_by_value(
        mask_mid_batch + mask_left_batch + mask_right_batch, 0, 1)

    # The three image placeholders share one static shape, so every view must
    # match the first in both height and width.
    imageH, imageW = img_list[0].shape[:2]
    for view_img in img_list[1:3]:
        if tuple(view_img.shape[:2]) != (imageH, imageW):
            raise ValueError('view images must share height/width: got %s, '
                             'expected %s' % (tuple(view_img.shape[:2]),
                                              (imageH, imageW)))
    image_mid_batch = tf.placeholder(dtype=tf.float32,
                                     shape=[1, imageH, imageW, 3],
                                     name='image_mid')
    image_left_batch = tf.placeholder(dtype=tf.float32,
                                      shape=[1, imageH, imageW, 3],
                                      name='image_left')
    image_right_batch = tf.placeholder(dtype=tf.float32,
                                       shape=[1, imageH, imageW, 3],
                                       name='image_right')

    NV = basis3dmm['basis_shape'].shape[1] // 3
    proj_xyz_mid_batch = tf.placeholder(dtype=tf.float32,
                                        shape=[1, NV, 3],
                                        name='proj_xyz_mid')
    proj_xyz_left_batch = tf.placeholder(dtype=tf.float32,
                                         shape=[1, NV, 3],
                                         name='proj_xyz_left')
    proj_xyz_right_batch = tf.placeholder(dtype=tf.float32,
                                          shape=[1, NV, 3],
                                          name='proj_xyz_right')

    ver_normals_mid_batch, _ = Projector.get_ver_norm(proj_xyz_mid_batch,
                                                      basis3dmm['tri'],
                                                      'normal_mid')
    ver_normals_left_batch, _ = Projector.get_ver_norm(proj_xyz_left_batch,
                                                       basis3dmm['tri'],
                                                       'normal_left')
    ver_normals_right_batch, _ = Projector.get_ver_norm(
        proj_xyz_right_batch, basis3dmm['tri'], 'normal_right')

    # Images are divided by 255 before unwrapping, so the per-view UV maps
    # share the [0, 1] scale of base_uv / 255 used in the blend below.
    uv_mid_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_mid_batch / 255.0, proj_xyz_mid_batch, ver_normals_mid_batch,
        basis3dmm, uv_size)
    uv_left_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_left_batch / 255.0, proj_xyz_left_batch, ver_normals_left_batch,
        basis3dmm, uv_size)
    uv_right_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_right_batch / 255.0, proj_xyz_right_batch,
        ver_normals_right_batch, basis3dmm, uv_size)

    uv_left_batch.set_shape((1, uv_size, uv_size, 3))
    uv_mid_batch.set_shape((1, uv_size, uv_size, 3))
    uv_right_batch.set_shape((1, uv_size, uv_size, 3))

    # laplacian pyramid blending: left+right, then mid, then over the base
    cur_uv = tf_blend_uv(uv_left_batch,
                         uv_right_batch,
                         mask_right_batch,
                         match_color=False)
    cur_uv = tf_blend_uv(cur_uv,
                         uv_mid_batch,
                         mask_mid_batch,
                         match_color=False)
    uv_batch = tf_blend_uv(base_uv_batch / 255,
                           cur_uv,
                           mask_batch,
                           match_color=True)
    uv_batch = uv_batch * 255

    uv_batch = tf.identity(uv_batch, name='uv_tex')

    print("uv_batch: ", uv_batch.shape)

    #------------------------------------------------------------------------------------------
    # build fitting graph
    uv_bases = basis3dmm['uv']
    para_tex = tf.get_variable(shape=[1, uv_bases['basis'].shape[0]],
                               initializer=tf.zeros_initializer(),
                               name='para_tex')

    uv_rgb, uv_mask = get_region_uv_texture(uv_bases, para_tex)
    print("uv_rgb: ", uv_rgb.shape)

    # build fitting loss; only terms with a positive weight are created, and
    # each enabled term is recorded so it can be fetched and printed by label.
    input_uv512_batch = tf.placeholder(dtype=tf.float32,
                                       shape=[1, uv_size, uv_size, 3],
                                       name='gt_uv')
    loss_terms = []  # (label, tensor) for every enabled loss component
    tot_loss = 0.
    if FLAGS.photo_weight > 0:
        photo_loss = Losses.photo_loss(uv_rgb / 255.0,
                                       input_uv512_batch / 255.0, uv_mask)
        tot_loss = tot_loss + photo_loss * FLAGS.photo_weight
        loss_terms.append(('photo', photo_loss))

    if FLAGS.uv_tv_weight > 0:
        uv_tv_loss = Losses.uv_tv_loss(uv_rgb / 255.0, uv_mask)
        tot_loss = tot_loss + uv_tv_loss * FLAGS.uv_tv_weight
        loss_terms.append(('tv', uv_tv_loss))

    if FLAGS.uv_reg_tex_weight > 0:
        uv_reg_tex_loss = Losses.reg_loss(para_tex)
        tot_loss = tot_loss + uv_reg_tex_loss * FLAGS.uv_reg_tex_weight
        loss_terms.append(('reg', uv_reg_tex_loss))

    if not loss_terms:
        raise ValueError('at least one of photo_weight, uv_tv_weight and '
                         'uv_reg_tex_weight must be > 0')
    # e.g. 'total:{}; photo:{}; tv:{}; reg:{}' when all terms are enabled
    loss_str = 'total:{}' + ''.join('; %s:{}' % label
                                    for label, _ in loss_terms)

    optim = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = optim.minimize(tot_loss)

    with tf.Session() as sess:

        if FLAGS.write_graph:
            tf.train.write_graph(sess.graph_def,
                                 '',
                                 FLAGS.pb_path,
                                 as_text=True)
            sys.exit()
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        (uv_merged, o_uv_left_batch, o_uv_mid_batch, o_uv_right_batch,
         o_mask_right_batch) = sess.run(
             [uv_batch, uv_left_batch, uv_mid_batch, uv_right_batch,
              mask_right_batch],
             {
                 image_mid_batch: img_list[0][np.newaxis, ...],
                 image_left_batch: img_list[1][np.newaxis, ...],
                 image_right_batch: img_list[2][np.newaxis, ...],
                 proj_xyz_mid_batch: pro_xyz_list[0][np.newaxis, ...],
                 proj_xyz_left_batch: pro_xyz_list[1][np.newaxis, ...],
                 proj_xyz_right_batch: pro_xyz_list[2][np.newaxis, ...],
             })
        print('  -------------  time wrap and merge:',
              time.time() - start_time)

        # blur will help; work on a copy so uv_merged is saved unblurred
        uv_extract = np.array(uv_merged[0], np.float32)
        if FLAGS.blur_times > 0:
            ksize = FLAGS.blur_kernel
            kernel = np.ones((ksize, ksize), np.float32) / (ksize * ksize)
            for _ in range(FLAGS.blur_times):
                uv_extract = cv2.filter2D(uv_extract, -1, kernel)
        if FLAGS.is_downsize:
            uv_extract = cv2.resize(uv_extract, (256, 256))
        uv_extract = cv2.resize(uv_extract, (uv_size, uv_size))
        uv_extract = uv_extract[np.newaxis, ...]

        # iter fit texture paras; fetch only the loss terms that exist
        fetches = [tot_loss] + [term for _, term in loss_terms] + [train_op]
        fit_feed = {input_uv512_batch: uv_extract}
        start_time = time.time()
        for i in range(FLAGS.train_step):
            results = sess.run(fetches, fit_feed)
            if i % 50 == 0:
                print(i, loss_str.format(*results[:-1]))
        para_tex_out = sess.run(para_tex)
        print(' ------------- time fit uv paras:', time.time() - start_time)

    uv_fit, face_mask = construct(uv_bases, para_tex_out, uv_size)
    uv_fit = blend_uv(base_uv / 255, uv_fit / 255, face_mask, True, 5)
    uv_fit = np.clip(uv_fit * 255, 0, 255)

    # save
    prefix = "bapi"

    def _save_png(image, suffix):
        """Squeeze `image`, cast it to uint8 and save it as <output_dir>/<prefix><suffix>."""
        Image.fromarray(np.squeeze(image).astype(np.uint8)).save(
            os.path.join(FLAGS.output_dir, prefix + suffix))

    # Clip so any out-of-range values saturate rather than wrap in the uint8 cast.
    _save_png(np.clip(uv_merged, 0, 255), '_tex_merge.png')
    _save_png(uv_fit, '_tex_fit.png')
    # The per-view UV maps were unwrapped from image / 255.0, so presumably
    # they are in [0, 1]; rescale before saving.
    _save_png(np.clip(o_uv_mid_batch * 255, 0, 255), '_uv_mid_batch.png')
    _save_png(np.clip(o_uv_left_batch * 255, 0, 255), '_uv_left_batch.png')
    _save_png(np.clip(o_uv_right_batch * 255, 0, 255), '_uv_right_batch.png')
    # NOTE(review): saved without rescaling, as before; if the preset masks
    # are in [0, 1] this image will be near-black — confirm their range.
    _save_png(o_mask_right_batch, '_mask_right_batch.png')

    np.save(os.path.join(FLAGS.output_dir, "para_tex_init.npy"), para_tex_out)
    print('---- step4 succeed -----')