コード例 #1
0
def main(_):
    """Unwrap fitted face images into a UV texture and save it as PNG files.

    Builds a TF1-style graph that unwraps the front view (and, when
    ``FLAGS.is_mult_view`` is set, also the left and right views) into UV
    space and blends the result onto a base texture.  The graph is then
    evaluated once for every ``*texture.mat`` file found in
    ``FLAGS.input_dir``; for each input, ``<prefix>_tex.png`` and
    ``<prefix>_mask.png`` are written to ``FLAGS.output_dir``.

    When ``FLAGS.write_graph`` is set, the graph definition is written to
    ``FLAGS.pb_path`` and the process exits without processing any input.

    Args:
        _: Unused positional argument.
    """
    # load 3dmm
    basis3dmm = load_3dmm_basis(
        FLAGS.basis3dmm_path,
        FLAGS.uv_path,
        is_whole_uv=True,
    )

    os.makedirs(FLAGS.output_dir, exist_ok=True)

    uv_size = FLAGS.uv_size
    # basis_shape stores x, y, z per vertex along axis 1 (hence the // 3).
    num_vertices = basis3dmm["basis_shape"].shape[1] // 3

    # ---- build graph ----
    front_image_batch = tf.compat.v1.placeholder(dtype=tf.float32,
                                                 shape=[1, None, None, 3],
                                                 name="front_image")
    front_image_batch_resized = tf.image.resize(front_image_batch,
                                                (uv_size, uv_size))
    front_seg_batch = tf.compat.v1.placeholder(dtype=tf.float32,
                                               shape=[1, None, None, 19],
                                               name="front_seg")
    front_proj_xyz_batch = tf.compat.v1.placeholder(
        dtype=tf.float32,
        shape=[1, num_vertices, 3],
        name="front_proj_xyz",
    )
    front_ver_norm_batch = tf.compat.v1.placeholder(
        dtype=tf.float32,
        shape=[1, num_vertices, 3],
        name="front_ver_norm",
    )

    base_uv_path = "../resources/base_tex.png"
    base_uv = Image.open(base_uv_path).resize((uv_size, uv_size))
    base_uv = np.asarray(base_uv, np.float32) / 255
    base_uv_batch = tf.constant(base_uv[np.newaxis, ...], name="base_uv")

    if FLAGS.is_mult_view:
        left_image_batch = tf.compat.v1.placeholder(dtype=tf.float32,
                                                    shape=[1, None, None, 3],
                                                    name="left_image")
        left_image_batch_resized = tf.image.resize(left_image_batch,
                                                   (uv_size, uv_size))
        left_seg_batch = tf.compat.v1.placeholder(dtype=tf.float32,
                                                  shape=[1, None, None, 19],
                                                  name="left_seg")
        left_proj_xyz_batch = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=[1, num_vertices, 3],
            name="left_proj_xyz",
        )
        left_ver_norm_batch = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=[1, num_vertices, 3],
            name="left_ver_norm",
        )

        right_image_batch = tf.compat.v1.placeholder(dtype=tf.float32,
                                                     shape=[1, None, None, 3],
                                                     name="right_image")
        right_image_batch_resized = tf.image.resize(right_image_batch,
                                                    (uv_size, uv_size))
        right_seg_batch = tf.compat.v1.placeholder(dtype=tf.float32,
                                                   shape=[1, None, None, 19],
                                                   name="right_seg")
        right_proj_xyz_batch = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=[1, num_vertices, 3],
            name="right_proj_xyz",
        )
        right_ver_norm_batch = tf.compat.v1.placeholder(
            dtype=tf.float32,
            shape=[1, num_vertices, 3],
            name="right_ver_norm",
        )

        def load_blend_mask(mask_path):
            """Load a blending mask image, resize it to the UV size, scale to 0-1."""
            mask_img = Image.open(mask_path).resize((uv_size, uv_size))
            return np.asarray(mask_img, np.float32) / 255

        # read fixed blending masks for multiview
        front_mask = load_blend_mask("../resources/mid_blend_mask.png")
        left_mask = load_blend_mask("../resources/left_blend_mask.png")
        right_mask = load_blend_mask("../resources/right_blend_mask.png")
        mask_front_batch = tf.constant(front_mask[np.newaxis, ..., np.newaxis],
                                       tf.float32,
                                       name="mask_front")
        mask_left_batch = tf.constant(left_mask[np.newaxis, ..., np.newaxis],
                                      tf.float32,
                                      name="mask_left")
        mask_right_batch = tf.constant(right_mask[np.newaxis, ..., np.newaxis],
                                       tf.float32,
                                       name="mask_right")

    # The image is resized to uv_size, so the projected coordinates are
    # rescaled with it (the divisor 300 matches the 300x300 inputs noted in
    # the data comments below).  The segmentation is not resized, so its
    # coordinates are used as-is.
    front_uv_batch, front_uv_mask_batch = unwrap_utils.unwrap_img_into_uv(
        front_image_batch_resized / 255.0,
        front_proj_xyz_batch * uv_size / 300,
        front_ver_norm_batch,
        basis3dmm,
        uv_size,
    )

    front_uv_seg_batch, _ = unwrap_utils.unwrap_img_into_uv(
        front_seg_batch,
        front_proj_xyz_batch,
        front_ver_norm_batch,
        basis3dmm,
        uv_size,
    )

    if FLAGS.is_mult_view:

        left_uv_batch, left_uv_mask_batch = unwrap_utils.unwrap_img_into_uv(
            left_image_batch_resized / 255.0,
            left_proj_xyz_batch * uv_size / 300,
            left_ver_norm_batch,
            basis3dmm,
            uv_size,
        )

        left_uv_seg_batch, _ = unwrap_utils.unwrap_img_into_uv(
            left_seg_batch,
            left_proj_xyz_batch,
            left_ver_norm_batch,
            basis3dmm,
            uv_size,
        )

        right_uv_batch, right_uv_mask_batch = unwrap_utils.unwrap_img_into_uv(
            right_image_batch_resized / 255.0,
            right_proj_xyz_batch * uv_size / 300,
            right_ver_norm_batch,
            basis3dmm,
            uv_size,
        )

        right_uv_seg_batch, _ = unwrap_utils.unwrap_img_into_uv(
            right_seg_batch,
            right_proj_xyz_batch,
            right_ver_norm_batch,
            basis3dmm,
            uv_size,
        )

        # blend multiview
        left_uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(
            left_uv_seg_batch)
        right_uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(
            right_uv_seg_batch)
        front_uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(
            front_uv_seg_batch)

        # Blend left with right first, then lay the front view on top.
        cur_seg = tf_blend_uv(
            left_uv_seg_mask_batch,
            right_uv_seg_mask_batch,
            mask_right_batch,
            match_color=False,
        )
        uv_seg_mask_batch = tf_blend_uv(cur_seg,
                                        front_uv_seg_mask_batch,
                                        mask_front_batch,
                                        match_color=False)

        mask_batch = tf.clip_by_value(
            mask_front_batch + mask_left_batch + mask_right_batch, 0, 1)
        uv_mask_batch = mask_batch * uv_seg_mask_batch
        cur_uv = tf_blend_uv(left_uv_batch,
                             right_uv_batch,
                             mask_right_batch,
                             match_color=False)
        cur_uv = tf_blend_uv(cur_uv,
                             front_uv_batch,
                             mask_front_batch,
                             match_color=False)
        uv_batch = tf_blend_uv(base_uv_batch,
                               cur_uv,
                               uv_mask_batch,
                               match_color=True)

    else:
        uv_seg_mask_batch = unwrap_utils.get_mask_from_seg(front_uv_seg_batch)
        uv_mask_batch = front_uv_mask_batch * uv_seg_mask_batch
        uv_batch = tf_blend_uv(base_uv_batch,
                               front_uv_batch,
                               uv_mask_batch,
                               match_color=True)

    # Give the outputs stable names in the graph.
    uv_batch = tf.identity(uv_batch, name="uv_tex")
    uv_seg_mask_batch = tf.identity(uv_seg_mask_batch, name="uv_seg")
    uv_mask_batch = tf.identity(uv_mask_batch, name="uv_mask")

    # Never run below; still created so the graph definition is unchanged.
    tf.compat.v1.global_variables_initializer()

    # Using the session as a context manager closes it exactly once, after
    # all inputs have been processed (closing it inside the loop would make
    # every iteration after the first fail on a closed session).
    with tf.compat.v1.Session() as sess:
        if FLAGS.write_graph:
            tf.io.write_graph(sess.graph_def, "", FLAGS.pb_path, as_text=True)
            exit()

        # ---- load data ----
        # seg: [300,300,19], segmentation
        # diffuse: [300,300,3], diffuse images
        # proj_xyz: [N,3]
        # ver_norm: [N,3]
        info_paths = glob.glob(os.path.join(FLAGS.input_dir, "*texture.mat"))
        image_key = "ori_img" if FLAGS.is_orig_img else "diffuse"

        for info_path in info_paths:
            info = scipy.io.loadmat(info_path)

            if FLAGS.is_mult_view:
                assert info["proj_xyz"].shape[0] >= 3  # front, left, right
            else:
                print(info["proj_xyz"].shape[0])
                assert info["proj_xyz"].shape[0] >= 1

            # View 0 is the front view and is always fed.
            feed_dict = {
                front_image_batch: info[image_key][0][np.newaxis, ...],
                front_proj_xyz_batch: info["proj_xyz"][0:1, ...],
                front_ver_norm_batch: info["ver_norm"][0:1, ...],
                front_seg_batch: info["seg"][0:1, ...],
            }
            if FLAGS.is_mult_view:
                # Views 1 and 2 are the left and right views.
                feed_dict.update({
                    left_image_batch: info[image_key][1][np.newaxis, ...],
                    left_proj_xyz_batch: info["proj_xyz"][1:2, ...],
                    left_ver_norm_batch: info["ver_norm"][1:2, ...],
                    left_seg_batch: info["seg"][1:2, ...],
                    right_image_batch: info[image_key][2][np.newaxis, ...],
                    right_proj_xyz_batch: info["proj_xyz"][2:3, ...],
                    right_ver_norm_batch: info["ver_norm"][2:3, ...],
                    right_seg_batch: info["seg"][2:3, ...],
                })

            uv_tex_res, uv_mask_res = sess.run([uv_batch, uv_mask_batch],
                                               feed_dict)

            # Drop the batch axis and scale up for saving as 8-bit images.
            uv_tex_res = uv_tex_res[0] * 255
            uv_mask_res = uv_mask_res[0] * 255

            # File name up to its first dot; basename() also handles
            # platform-specific path separators.
            prefix = os.path.basename(info_path).split(".")[0]
            Image.fromarray(uv_tex_res.astype(np.uint8)).save(
                os.path.join(FLAGS.output_dir, prefix + "_tex.png"))
            Image.fromarray(np.squeeze(uv_mask_res).astype(np.uint8)).save(
                os.path.join(FLAGS.output_dir, prefix + "_mask.png"))
コード例 #2
0
def RGB_opt(_):
    """Run the RGB fitting optimization and export the final results.

    Loads the 3DMM basis (BFM or non-BFM, depending on ``FLAGS.is_bfm``) and
    the RGB input data, builds the optimization graph, and runs
    ``FLAGS.train_step`` training steps.  Summaries are written to
    ``FLAGS.summary_dir`` every ``FLAGS.log_step`` steps and on the last
    step.  On the last step, when ``FLAGS.save_ply`` is set, ``face.ply``
    (plus ``face.obj`` for the non-BFM model) and ``out_for_texture.mat``
    are written to ``FLAGS.out_dir``.

    Args:
        _: Unused positional argument.
    """
    import time

    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.GPU_NO
    # load 3DMM
    if FLAGS.is_bfm is False:
        basis3dmm = load_3dmm_basis(
            FLAGS.basis3dmm_path,
            FLAGS.uv_path,
        )
        para_shape_shape = basis3dmm["basis_shape"].shape[0]
        para_tex_shape = basis3dmm["uv"]["basis"].shape[0]
    else:
        basis3dmm = load_3dmm_basis_bfm(FLAGS.basis3dmm_path)
        para_shape_shape = basis3dmm["basis_shape"].shape[0]
        para_tex_shape = basis3dmm["basis_tex"].shape[0]

    # load RGB data
    info = RGB_load.load_rgb_data(FLAGS.base_dir, FLAGS.project_type,
                                  FLAGS.num_of_img)

    imageH = info["img_list"].shape[1]
    imageW = info["img_list"].shape[2]

    # build graph
    var_list = define_variable(FLAGS.num_of_img, imageH, imageW,
                               para_shape_shape, para_tex_shape, info)

    out_list = build_RGB_opt_graph(var_list, basis3dmm, imageH, imageW)

    # Make sure the output locations exist before anything writes to them.
    os.makedirs(FLAGS.summary_dir, exist_ok=True)
    os.makedirs(FLAGS.out_dir, exist_ok=True)

    # summary_op
    summary_op = tf.compat.v1.summary.merge_all()
    summary_writer = tf.compat.v1.summary.FileWriter(FLAGS.summary_dir)

    # start opt
    config = tf.compat.v1.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction=0.5
    config.gpu_options.allow_growth = True

    last_step = FLAGS.train_step - 1
    try:
        with tf.compat.v1.Session(config=config) as sess:
            sess.run(tf.compat.v1.global_variables_initializer())

            starttime = time.time()

            for step in range(FLAGS.train_step):

                if step % FLAGS.log_step == 0 or step == last_step:
                    out_summary = sess.run(summary_op)
                    summary_writer.add_summary(out_summary, step)
                    print("step: " + str(step))
                    endtime = time.time()
                    print("time:" + str(endtime - starttime))
                    starttime = time.time()

                if step == last_step and FLAGS.save_ply:
                    print("output_final_result...")
                    out_ver_xyz, out_tex = sess.run(
                        [out_list["ver_xyz"], out_list["tex"]])
                    # output ply
                    v_xyz = out_ver_xyz[0]
                    if FLAGS.is_bfm is False:
                        # Per-vertex colour is looked up in the UV map through
                        # each triangle's texture coordinates.
                        uv_map = out_tex[0] * 255.0
                        uv_size = uv_map.shape[0]
                        # A coordinate that maps to exactly uv_size would
                        # index one past the end, so clamp to the last texel.
                        # The same bound is used for rows and columns, as the
                        # original lookup scales both by uv_map.shape[0].
                        max_idx = uv_size - 1
                        vt_list = basis3dmm["vt_list"]
                        # Vertices not referenced by any triangle keep 200.
                        v_rgb = np.zeros_like(v_xyz) + 200  # N x 3
                        for tri_ver, tri_tex in zip(basis3dmm["tri"],
                                                    basis3dmm["tri_vt"]):
                            for ver_idx, vt_idx in zip(tri_ver, tri_tex):
                                row = min(
                                    int((1.0 - vt_list[vt_idx][1]) * uv_size),
                                    max_idx)
                                col = min(int(vt_list[vt_idx][0] * uv_size),
                                          max_idx)
                                v_rgb[ver_idx] = uv_map[row, col]

                        write_obj(
                            os.path.join(FLAGS.out_dir, "face.obj"),
                            v_xyz,
                            basis3dmm["vt_list"],
                            basis3dmm["tri"].astype(np.int32),
                            basis3dmm["tri_vt"].astype(np.int32),
                        )
                    else:
                        v_rgb = out_tex[0] * 255.0

                    write_ply(
                        os.path.join(FLAGS.out_dir, "face.ply"),
                        v_xyz,
                        basis3dmm["tri"],
                        v_rgb.astype(np.uint8),
                        True,
                    )

                    out_diffuse, out_proj_xyz, out_ver_norm = sess.run([
                        out_list["diffuse"], out_list["proj_xyz"],
                        out_list["ver_norm"]
                    ])
                    out_diffuse = out_diffuse * 255.0  # RGB 0-255
                    scio.savemat(
                        os.path.join(FLAGS.out_dir, "out_for_texture.mat"),
                        {
                            "ori_img": info["img_ori_list"],  # ? x ?
                            "diffuse": out_diffuse,  # 300 x 300
                            "seg": info["seg_list"],  # 300 x 300
                            "proj_xyz": out_proj_xyz,  # in 300 x 300 img
                            "ver_norm": out_ver_norm,
                        },
                    )

                # The results above are exported before this step's update.
                sess.run(out_list["train_op"])
    finally:
        # Flush pending summary events and release the event file.
        summary_writer.close()
コード例 #3
0
def main(_):
    """Merge three views into one UV texture and fit texture parameters to it.

    Steps:
      1. Load the 3DMM basis and the images, masks, base UV and projected
         vertices via ``load_from_npz(FLAGS.output_dir, basis3dmm)``.
      2. Unwrap the mid/left/right images into 512x512 UV maps and blend them
         onto the base UV texture.
      3. Optionally blur / down-size the merged texture, then fit ``para_tex``
         to it for ``FLAGS.train_step`` Adam steps using the losses whose
         weights in ``FLAGS`` are positive.
      4. Save the merged and fitted textures, per-view debug images and the
         fitted parameters (``para_tex_init.npy``) to ``FLAGS.output_dir``.

    When ``FLAGS.write_graph`` is set, the graph definition is written to
    ``FLAGS.pb_path`` and the process exits instead.

    Args:
        _: Unused positional argument.
    """
    print('---- step4 start -----')
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.GPU_NO
    print('running base:', FLAGS.output_dir)

    uv_size = 512

    basis3dmm = load_3dmm_basis(FLAGS.basis3dmm_path,
                                FLAGS.uv_path,
                                is_whole_uv=True)

    # load data (all : 0-255-float32)
    img_list, _pro_yx_list, preset_masks_list, base_uv, pro_xyz_list = load_from_npz(
        FLAGS.output_dir, basis3dmm)

    # set tf datas
    base_uv_batch = tf.constant(base_uv[np.newaxis, ...], name='base_uv')
    mask_mid_batch = tf.constant(
        preset_masks_list[0][np.newaxis, ..., np.newaxis], tf.float32)
    mask_left_batch = tf.constant(
        preset_masks_list[1][np.newaxis, ..., np.newaxis], tf.float32)
    mask_right_batch = tf.constant(
        preset_masks_list[2][np.newaxis, ..., np.newaxis], tf.float32)
    mask_batch = tf.clip_by_value(
        mask_mid_batch + mask_left_batch + mask_right_batch, 0, 1)

    imageH = img_list[0].shape[0]
    imageW = img_list[0].shape[1]
    # All three views are fed through placeholders of the same fixed size, so
    # every view must match in both height and width.
    assert all(img.shape[:2] == (imageH, imageW) for img in img_list[:3])
    image_mid_batch = tf.placeholder(dtype=tf.float32,
                                     shape=[1, imageH, imageW, 3],
                                     name='image_mid')
    image_left_batch = tf.placeholder(dtype=tf.float32,
                                      shape=[1, imageH, imageW, 3],
                                      name='image_left')
    image_right_batch = tf.placeholder(dtype=tf.float32,
                                       shape=[1, imageH, imageW, 3],
                                       name='image_right')

    NV = basis3dmm['basis_shape'].shape[1] // 3
    proj_xyz_mid_batch = tf.placeholder(dtype=tf.float32,
                                        shape=[1, NV, 3],
                                        name='proj_xyz_mid')
    proj_xyz_left_batch = tf.placeholder(dtype=tf.float32,
                                         shape=[1, NV, 3],
                                         name='proj_xyz_left')
    proj_xyz_right_batch = tf.placeholder(dtype=tf.float32,
                                          shape=[1, NV, 3],
                                          name='proj_xyz_right')

    ver_normals_mid_batch, _ = Projector.get_ver_norm(proj_xyz_mid_batch,
                                                      basis3dmm['tri'],
                                                      'normal_mid')
    ver_normals_left_batch, _ = Projector.get_ver_norm(proj_xyz_left_batch,
                                                       basis3dmm['tri'],
                                                       'normal_left')
    ver_normals_right_batch, _ = Projector.get_ver_norm(
        proj_xyz_right_batch, basis3dmm['tri'], 'normal_right')

    uv_mid_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_mid_batch / 255.0,
        proj_xyz_mid_batch,
        ver_normals_mid_batch,
        basis3dmm,
        uv_size)

    uv_left_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_left_batch / 255.0,
        proj_xyz_left_batch,
        ver_normals_left_batch,
        basis3dmm,
        uv_size)

    uv_right_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_right_batch / 255.0,
        proj_xyz_right_batch,
        ver_normals_right_batch,
        basis3dmm,
        uv_size)

    uv_left_batch.set_shape((1, uv_size, uv_size, 3))
    uv_mid_batch.set_shape((1, uv_size, uv_size, 3))
    uv_right_batch.set_shape((1, uv_size, uv_size, 3))

    # Blend the views: left with right first, then the mid view on top, and
    # finally the result onto the base texture (which is scaled to 0-1 for
    # blending and back to 0-255 afterwards).
    cur_uv = tf_blend_uv(uv_left_batch,
                         uv_right_batch,
                         mask_right_batch,
                         match_color=False)
    cur_uv = tf_blend_uv(cur_uv,
                         uv_mid_batch,
                         mask_mid_batch,
                         match_color=False)
    uv_batch = tf_blend_uv(base_uv_batch / 255,
                           cur_uv,
                           mask_batch,
                           match_color=True)
    uv_batch = uv_batch * 255

    uv_batch = tf.identity(uv_batch, name='uv_tex')

    print("uv_batch: ", uv_batch.shape)

    #------------------------------------------------------------------------------------------
    # build fitting graph
    uv_bases = basis3dmm['uv']
    para_tex = tf.get_variable(shape=[1, uv_bases['basis'].shape[0]],
                               initializer=tf.zeros_initializer(),
                               name='para_tex')

    uv_rgb, uv_mask = get_region_uv_texture(uv_bases, para_tex)
    print("uv_rgb: ", uv_rgb.shape)

    # build fitting loss
    input_uv512_batch = tf.placeholder(dtype=tf.float32,
                                       shape=[1, uv_size, uv_size, 3],
                                       name='gt_uv')
    tot_loss = 0.
    # Only the enabled loss terms are recorded, so the fetch list and the log
    # format always agree with each other.
    loss_names = ['total']
    loss_ops = []
    if FLAGS.photo_weight > 0:
        photo_loss = Losses.photo_loss(uv_rgb / 255.0,
                                       input_uv512_batch / 255.0, uv_mask)
        tot_loss = tot_loss + photo_loss * FLAGS.photo_weight
        loss_names.append('photo')
        loss_ops.append(photo_loss)

    if FLAGS.uv_tv_weight > 0:
        uv_tv_loss = Losses.uv_tv_loss(uv_rgb / 255.0, uv_mask)
        tot_loss = tot_loss + uv_tv_loss * FLAGS.uv_tv_weight
        loss_names.append('tv')
        loss_ops.append(uv_tv_loss)

    if FLAGS.uv_reg_tex_weight > 0:
        uv_reg_tex_loss = Losses.reg_loss(para_tex)
        tot_loss = tot_loss + uv_reg_tex_loss * FLAGS.uv_reg_tex_weight
        loss_names.append('reg')
        loss_ops.append(uv_reg_tex_loss)

    loss_str = '; '.join(name + ':{}' for name in loss_names)
    optim = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = optim.minimize(tot_loss)

    with tf.Session() as sess:

        if FLAGS.write_graph:
            tf.train.write_graph(sess.graph_def,
                                 '',
                                 FLAGS.pb_path,
                                 as_text=True)
            exit()
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        uv_merged, o_uv_left_batch, o_uv_mid_batch, o_uv_right_batch, o_mask_right_batch = sess.run(
            [uv_batch, uv_left_batch, uv_mid_batch, uv_right_batch,
             mask_right_batch],
            {
                image_mid_batch: img_list[0][np.newaxis, ...],
                image_left_batch: img_list[1][np.newaxis, ...],
                image_right_batch: img_list[2][np.newaxis, ...],
                proj_xyz_mid_batch: pro_xyz_list[0][np.newaxis, ...],
                proj_xyz_left_batch: pro_xyz_list[1][np.newaxis, ...],
                proj_xyz_right_batch: pro_xyz_list[2][np.newaxis, ...],
            })
        # Keep an untouched copy of the merged texture for saving later.
        uv_extract_dump = np.copy(uv_merged)
        print('  -------------  time wrap and merge:',
              time.time() - start_time)

        # blur will help
        uv_extract = np.asarray(uv_merged[0], np.float32)
        # Box-filter kernel; it does not change between blur passes.
        kernel = np.ones((FLAGS.blur_kernel, FLAGS.blur_kernel),
                         np.float32) / (FLAGS.blur_kernel * FLAGS.blur_kernel)
        for _ in range(FLAGS.blur_times):
            uv_extract = cv2.filter2D(uv_extract, -1, kernel)
        if FLAGS.is_downsize:
            uv_extract = cv2.resize(uv_extract, (256, 256))
        uv_extract = cv2.resize(uv_extract, (uv_size, uv_size))
        uv_extract = uv_extract[np.newaxis, ...]

        # iteratively fit the texture parameters
        start_time = time.time()
        fit_feed = {input_uv512_batch: uv_extract}
        fetches = [tot_loss] + loss_ops + [train_op]
        for i in range(FLAGS.train_step):
            results = sess.run(fetches, fit_feed)
            if i % 50 == 0:
                # The last fetched entry is train_op's (empty) result.
                print(i, loss_str.format(*results[:-1]))
        para_tex_out = sess.run(para_tex, fit_feed)
        print(' ------------- time fit uv paras:', time.time() - start_time)

    uv_fit, face_mask = construct(uv_bases, para_tex_out, uv_size)
    uv_fit = blend_uv(base_uv / 255, uv_fit / 255, face_mask, True, 5)
    uv_fit = np.clip(uv_fit * 255, 0, 255)

    # save
    prefix = "bapi"

    def save_png(array, suffix):
        """Save ``array`` (singleton axes dropped) as an 8-bit image."""
        Image.fromarray(np.squeeze(array).astype(np.uint8)).save(
            os.path.join(FLAGS.output_dir, prefix + suffix))

    # Clip before the uint8 cast so out-of-range values do not wrap around.
    save_png(np.clip(uv_extract_dump, 0, 255), '_tex_merge.png')
    save_png(uv_fit, '_tex_fit.png')
    # The per-view UV maps come from images that were divided by 255 in the
    # graph, so they are scaled back to 0-255 before saving.
    save_png(np.clip(o_uv_mid_batch * 255, 0, 255), '_uv_mid_batch.png')
    save_png(np.clip(o_uv_left_batch * 255, 0, 255), '_uv_left_batch.png')
    save_png(np.clip(o_uv_right_batch * 255, 0, 255), '_uv_right_batch.png')
    # NOTE(review): the mask is saved unscaled, as before; if the preset
    # masks are 0-1 this image will look black - confirm the intended range.
    save_png(o_mask_right_batch, '_mask_right_batch.png')

    np.save(os.path.join(FLAGS.output_dir, "para_tex_init.npy"), para_tex_out)
    print('---- step4 succeed -----')
コード例 #4
0
def RGBD_opt(_):
    """Run the RGBD fitting optimization and export the final results.

    Loads the 3DMM basis (BFM or non-BFM, depending on ``FLAGS.is_bfm``) and
    the preprocessed RGBD data, builds the optimization graph, and runs
    ``FLAGS.train_step`` training steps.  Summaries are written to
    ``FLAGS.summary_dir`` every ``FLAGS.log_step`` steps and on the last
    step.  On the last step, when ``FLAGS.save_ply`` is set, ``face.ply`` and
    ``out_for_texture.mat`` are written to ``FLAGS.out_dir``; for the non-BFM
    model ``face.obj`` and ``head.obj`` (face with a head added) are written
    as well.

    Args:
        _: Unused positional argument.
    """
    import time

    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.GPU_NO

    # load data
    if FLAGS.is_bfm is False:
        basis3dmm = load_3dmm_basis(
            FLAGS.basis3dmm_path,
            FLAGS.uv_path,
        )
    else:
        basis3dmm = load_3dmm_basis_bfm(FLAGS.basis3dmm_path)

    # load_RGBD_data, sequence index is: mid --  left -- right -- up
    info = RGBD_load.load_and_preprocess_RGBD_data(
        FLAGS.prefit_dir, FLAGS.prepare_dir, basis3dmm
    )

    imageH = info["height"]
    imageW = info["width"]
    para_shape_shape = info["para_shape"].shape[1]
    para_tex_shape = info["para_tex"].shape[1]

    # build graph
    var_list = define_variable(
        FLAGS.num_of_img, imageH, imageW, para_shape_shape, para_tex_shape, info
    )

    out_list = build_RGBD_opt_graph(var_list, basis3dmm, imageH, imageW)

    # Make sure the output locations exist before anything writes to them.
    os.makedirs(FLAGS.summary_dir, exist_ok=True)
    os.makedirs(FLAGS.out_dir, exist_ok=True)

    # summary_op
    summary_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter(FLAGS.summary_dir)

    # start opt
    config = tf.ConfigProto()
    # config.gpu_options.per_process_gpu_memory_fraction=0.5
    config.gpu_options.allow_growth = True

    last_step = FLAGS.train_step - 1
    try:
        with tf.Session(config=config) as sess:
            sess.run(tf.global_variables_initializer())

            starttime = time.time()

            for step in range(FLAGS.train_step):

                if step % FLAGS.log_step == 0 or step == last_step:
                    out_summary = sess.run(summary_op)
                    summary_writer.add_summary(out_summary, step)
                    print("step: " + str(step))
                    endtime = time.time()
                    print("time:" + str(endtime - starttime))
                    starttime = time.time()

                if step == last_step and FLAGS.save_ply:
                    print("output_final_result...")
                    out_ver_xyz, out_tex = sess.run(
                        [out_list["ver_xyz"], out_list["tex"]]
                    )
                    # output ply
                    v_xyz = out_ver_xyz[0]
                    if FLAGS.is_bfm is False:
                        # Per-vertex colour is looked up in the UV map through
                        # each triangle's texture coordinates.
                        uv_map = out_tex[0] * 255.0
                        uv_size = uv_map.shape[0]
                        # A coordinate that maps to exactly uv_size would
                        # index one past the end, so clamp to the last texel.
                        # The same bound is used for rows and columns, as the
                        # original lookup scales both by uv_map.shape[0].
                        max_idx = uv_size - 1
                        vt_list = basis3dmm["vt_list"]
                        # Vertices not referenced by any triangle keep 200.
                        v_rgb = np.zeros_like(v_xyz) + 200  # N x 3
                        for tri_ver, tri_tex in zip(
                            basis3dmm["tri"], basis3dmm["tri_vt"]
                        ):
                            for ver_idx, vt_idx in zip(tri_ver, tri_tex):
                                row = min(
                                    int((1.0 - vt_list[vt_idx][1]) * uv_size),
                                    max_idx,
                                )
                                col = min(
                                    int(vt_list[vt_idx][0] * uv_size), max_idx
                                )
                                v_rgb[ver_idx] = uv_map[row, col]

                        write_obj(
                            os.path.join(FLAGS.out_dir, "face.obj"),
                            v_xyz,
                            basis3dmm["vt_list"],
                            basis3dmm["tri"].astype(np.int32),
                            basis3dmm["tri_vt"].astype(np.int32),
                        )
                    else:
                        v_rgb = out_tex[0] * 255.0

                    write_ply(
                        os.path.join(FLAGS.out_dir, "face.ply"),
                        v_xyz,
                        basis3dmm["tri"],
                        v_rgb.astype(np.uint8),
                        True,
                    )

                    ## add head
                    if FLAGS.is_bfm is False:
                        print("-------------------start add head-------------------")
                        HeadModel = np.load(
                            FLAGS.info_for_add_head, allow_pickle=True
                        ).item()
                        # Re-read the face mesh that was just written above.
                        vertex = read_obj(os.path.join(FLAGS.out_dir, "face.obj"))
                        vertex = vertex.transpose()
                        vertex_fit_h = vertex[:, HeadModel["head_h_idx"]]
                        pca_info_h = AddHeadTool.transfer_PCA_format_for_add_head(
                            basis3dmm, HeadModel
                        )
                        vertex_output_coord = AddHeadTool.fix_back_head(
                            vertex_fit_h,
                            HeadModel,
                            pca_info_h,
                            FLAGS.is_add_head_mirrow,
                            FLAGS.is_add_head_male,
                        )
                        write_obj(
                            os.path.join(FLAGS.out_dir, "head.obj"),
                            vertex_output_coord.transpose(),
                            HeadModel["head_vt_list"],
                            HeadModel["head_tri"],
                            HeadModel["head_tri_vt"],
                        )
                        print("-------------------add head successfully-------------------")

                    out_diffuse, out_proj_xyz, out_ver_norm = sess.run(
                        [out_list["diffuse"], out_list["proj_xyz"], out_list["ver_norm"]]
                    )
                    out_diffuse = out_diffuse * 255.0  # RGB 0-255
                    scio.savemat(
                        os.path.join(FLAGS.out_dir, "out_for_texture.mat"),
                        {
                            "ori_img": info["ori_img"],  # ? x ?
                            "diffuse": out_diffuse,  # 300 x 300
                            "seg": info["seg_list"],  # 300 x 300
                            "proj_xyz": out_proj_xyz,  # in 300 x 300 img
                            "ver_norm": out_ver_norm,
                        },
                    )

                # The results above are exported before this step's update.
                sess.run(out_list["train_op"])
    finally:
        # Flush pending summary events and release the event file.
        summary_writer.close()
コード例 #5
0
def main(_):
    """Fit regional UV texture coefficients to each texture/mask pair and save the results.

    Builds a TF1 graph with one zero-initialised coefficient variable per UV
    region of the loaded 3DMM basis. For every ``*tex.png`` / ``*mask.png``
    pair found in ``FLAGS.data_dir`` (paired by sorted order) the variables
    are re-initialised, ``FLAGS.train_step`` Adam steps are run, and the
    blended diffuse (``<prefix>_D.png``) and normal (``<prefix>_N.png``) maps
    are written to ``FLAGS.out_dir``.

    When ``FLAGS.write_graph`` is set, the graph definition is written as text
    to ``FLAGS.pb_path`` and the process exits without fitting anything.

    Args:
        _: Unused positional argument (presumably argv handed over by the app
            runner -- confirm against the caller).

    Raises:
        ValueError: If none of ``FLAGS.photo_weight``, ``FLAGS.uv_tv_weight``
            and ``FLAGS.uv_reg_tex_weight`` is positive, i.e. there is
            nothing to optimise.
        SystemExit: After the graph has been written when
            ``FLAGS.write_graph`` is set.
    """
    uv_size = 512  # side length of the square UV maps the graph is built for

    # save parameters
    save_params()

    mask_batch = tf.placeholder(
        dtype=tf.float32, shape=[1, uv_size, uv_size, 1], name="uv_mask"
    )
    tex_batch = tf.placeholder(
        dtype=tf.float32, shape=[1, uv_size, uv_size, 3], name="uv_tex"
    )

    # The inputs are copied into non-trainable variables once per image, so
    # the optimisation steps can run without a feed dict.
    var_mask_batch = tf.get_variable(
        shape=[1, uv_size, uv_size, 1],
        dtype=tf.float32,
        name="var_mask",
        trainable=False,
    )
    var_tex_batch = tf.get_variable(
        shape=[1, uv_size, uv_size, 3],
        dtype=tf.float32,
        name="var_tex",
        trainable=False,
    )

    assign_op = tf.group(
        [tf.assign(var_mask_batch, mask_batch), tf.assign(var_tex_batch, tex_batch)],
        name="assign_op",
    )

    # load 3dmm
    basis3dmm = load_3dmm_basis(
        FLAGS.basis3dmm_path,
        FLAGS.uv_path,
        uv_weight_mask_path=FLAGS.uv_weight_mask_path,
        is_train=False,
        is_whole_uv=False,
    )

    # build fitting graph: one coefficient row vector per UV region
    uv_region_bases = basis3dmm["uv"]
    para_uv_dict = {}
    for region_name in uv_region_bases:
        region_basis = uv_region_bases[region_name]
        para_uv_dict[region_name] = tf.get_variable(
            shape=[1, region_basis["basis"].shape[0]],
            initializer=tf.zeros_initializer(),
            name="para_" + region_name,
        )

    uv_rgb, uv_mask = get_uv_texture(uv_region_bases, para_uv_dict)
    photo_weight_mask = get_weighted_photo_mask(uv_region_bases)

    # build fitting loss; only the enabled terms are recorded so that the
    # fetch list and the log labels always line up with what exists in the graph
    tot_loss = 0.0
    loss_names = []
    loss_tensors = []

    if FLAGS.photo_weight > 0:
        photo_loss = Losses.ws_photo_loss(
            var_tex_batch, uv_rgb / 255.0, uv_mask * photo_weight_mask * var_mask_batch
        )
        photo_loss = tf.identity(photo_loss, name="photo_loss")
        tot_loss = tot_loss + photo_loss * FLAGS.photo_weight
        loss_names.append("photo")
        loss_tensors.append(photo_loss)

    if FLAGS.uv_tv_weight > 0:
        uv_tv_loss = Losses.uv_tv_loss2(
            uv_rgb / 255, uv_mask, basis3dmm["uv_weight_mask"]
        )
        uv_tv_loss = tf.identity(uv_tv_loss, name="uv_tv_loss")
        tot_loss = tot_loss + uv_tv_loss * FLAGS.uv_tv_weight
        loss_names.append("tv")
        loss_tensors.append(uv_tv_loss)

    if FLAGS.uv_reg_tex_weight > 0:
        uv_reg_tex_loss = 0.0
        for key in para_uv_dict:
            uv_reg_tex_loss = uv_reg_tex_loss + Losses.reg_loss(para_uv_dict[key])
        # average the regulariser over the regions
        uv_reg_tex_loss = uv_reg_tex_loss / len(para_uv_dict)
        uv_reg_tex_loss = tf.identity(uv_reg_tex_loss, name="uv_reg_tex_loss")
        tot_loss = tot_loss + uv_reg_tex_loss * FLAGS.uv_reg_tex_weight
        loss_names.append("reg")
        loss_tensors.append(uv_reg_tex_loss)

    if not loss_tensors:
        raise ValueError(
            "no loss term enabled: at least one of photo_weight, uv_tv_weight "
            "and uv_reg_tex_weight must be > 0"
        )

    # e.g. "total:{};photo:{};tv:{};reg:{}" when every term is enabled
    loss_str = ";".join(name + ":{}" for name in ["total"] + loss_names)

    optim = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = optim.minimize(tot_loss, name="train_op")
    init_op = tf.global_variables_initializer()

    # train_op goes last so its (None) result is easy to drop from the output
    fetches = [tot_loss] + loss_tensors + [train_op]

    with tf.Session() as sess:

        if FLAGS.write_graph:
            tf.train.write_graph(sess.graph_def, "", FLAGS.pb_path, as_text=True)
            # same effect as exit(), without relying on the `site` helper
            raise SystemExit(0)

        uv_paths = sorted(glob.glob(os.path.join(FLAGS.data_dir, "*tex.png")))
        mask_paths = sorted(glob.glob(os.path.join(FLAGS.data_dir, "*mask.png")))
        if len(uv_paths) != len(mask_paths):
            # zip() below stops at the shorter list and pairs purely by sorted
            # position, so a count mismatch most likely means wrong pairs.
            print(
                "warning: found {} texture(s) but {} mask(s) in {}; "
                "pairing by sorted order and ignoring the surplus".format(
                    len(uv_paths), len(mask_paths), FLAGS.data_dir
                )
            )

        if not os.path.isdir(FLAGS.out_dir):
            os.makedirs(FLAGS.out_dir)

        # base uv / normal maps the fitted face region is blended into
        base_uv = Image.open("../resources/base_tex.png")
        base_uv = np.asarray(base_uv, np.float32) / 255.0
        base_normal = Image.open("../resources/base_normal.png")
        base_normal = np.asarray(base_normal, np.float32) / 255.0

        for uv_path, mask_path in zip(uv_paths, mask_paths):
            uv_input = np.asarray(Image.open(uv_path)).astype(np.float32) / 255.0
            mask_input = np.asarray(Image.open(mask_path)).astype(np.float32) / 255.0

            # Check both spatial dimensions: the placeholders need exactly
            # uv_size x uv_size, not merely uv_size rows.
            if mask_input.shape[:2] != (uv_size, uv_size):
                mask_input = cv2.resize(mask_input, (uv_size, uv_size))
            if uv_input.shape[:2] != (uv_size, uv_size):
                uv_input = cv2.resize(uv_input, (uv_size, uv_size))

            # the mask placeholder is single-channel: add the channel axis if
            # it is missing, then keep the first channel only
            if len(mask_input.shape) != 3:
                mask_input = mask_input[..., np.newaxis]
            mask_input = mask_input[:, :, 0:1]

            uv_input = uv_input[np.newaxis, ...]
            mask_input = mask_input[np.newaxis, ...]

            # every image is fitted from scratch: reset all variables first
            sess.run(init_op)
            sess.run(assign_op, {tex_batch: uv_input, mask_batch: mask_input})

            for i in range(FLAGS.train_step):
                loss_values = sess.run(fetches)[:-1]
                if i % 100 == 0:
                    print(i, loss_str.format(*loss_values))

            para_out_dict = sess.run(para_uv_dict)

            # rebuild texture and normal map from the "uv2k" / "normal2k"
            # bases using the fitted coefficients
            face_uv, _ = np_get_uv_texture(basis3dmm["uv2k"], para_out_dict)
            face_normal, face_mask = np_get_uv_texture(
                basis3dmm["normal2k"], para_out_dict
            )

            face_uv = np.clip(face_uv, 0, 255)
            face_normal = np.clip(face_normal, 0, 255)
            face_mask = np.clip(face_mask, 0, 1)

            # file name up to its first dot, independent of the path separator
            prefix = os.path.basename(uv_path).split(".")[0]
            print(prefix)

            out_uv = blend_uv(base_uv, face_uv / 255, face_mask, True)
            out_normal = blend_uv(base_normal, face_normal / 255, face_mask, False)
            out_uv = np.clip(out_uv * 255, 0, 255)
            out_normal = np.clip(out_normal * 255, 0, 255)

            out_uv = Image.fromarray(out_uv.astype(np.uint8))
            out_normal = Image.fromarray(out_normal.astype(np.uint8))
            out_uv.save(os.path.join(FLAGS.out_dir, prefix + "_D.png"))
            out_normal.save(os.path.join(FLAGS.out_dir, prefix + "_N.png"))