Example #1
def main():
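    # Blend the pix2pix-refined texture and the unwrapped texture into the fitted
    # texture using the face-region mask, after matching mean color. Assumes the
    # surrounding module provides `os`, `numpy as np`, `PIL.Image`, the parsed
    # arguments `a`, and the repo's `blend_uv` helper.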

    mask_path = "../resources/mask_used.png"
    fit_path = os.path.join(a.input_fit_dir, "out_for_texture_tex_D.png")
    pix2pix_path = os.path.join(a.input_pix2pix_dir,
                                "out_for_texture_tex_D.png")
    unwrap_path = os.path.join(a.input_unwrap_dir, "out_for_texture_tex.png")

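    # Load the face mask and the three textures; the mask and the unwrapped
    # texture are resized to the 2048x2048 UV resolution, the mask is scaled to
    # [0, 1], and the textures stay in [0, 255].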
    mask_img = (
        np.asarray(Image.open(mask_path).resize(
            (2048, 2048)), np.float32) / 255.0)
    mask_img = mask_img[..., np.newaxis]
    fit_img = np.asarray(Image.open(fit_path), np.float32)
    pix2pix_img = np.asarray(Image.open(pix2pix_path), np.float32)
    unwrap_img = np.asarray(
        Image.open(unwrap_path).resize((2048, 2048)), np.float32)

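    # Shift the pix2pix texture so its mean color inside the masked region
    # matches that of the fitted texture.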
    fit_mu = np.sum(fit_img * mask_img, axis=(0, 1)) / np.sum(mask_img,
                                                              axis=(0, 1))
    pix2pix_mu = np.sum(pix2pix_img * mask_img, axis=(0, 1)) / np.sum(
        mask_img, axis=(0, 1))

    pix2pix_img = pix2pix_img - pix2pix_mu + fit_mu
    pix2pix_img = np.clip(pix2pix_img, 0, 255)

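    # Blend the color-matched pix2pix texture and the unwrapped texture into the
    # fitted texture with blend_uv (inputs scaled to [0, 1]); what times=7
    # controls depends on blend_uv, which is defined outside this snippet.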
    mask_img = np.concatenate([mask_img] * 3, axis=-1)
    pix2pix_img = blend_uv(fit_img / 255,
                           pix2pix_img / 255,
                           mask_img,
                           match_color=False,
                           times=7)
    pix2pix_img = pix2pix_img * 255
    unwrap_img = blend_uv(fit_img / 255,
                          unwrap_img / 255,
                          mask_img,
                          match_color=False,
                          times=7)
    unwrap_img = unwrap_img * 255

    if not os.path.exists(a.output_dir):
        os.makedirs(a.output_dir)

    Image.fromarray(pix2pix_img.astype(np.uint8)).save(
        os.path.join(a.output_dir, "output_for_texture_tex_D.png"))
    Image.fromarray(unwrap_img.astype(np.uint8)).save(
        os.path.join(a.output_dir, "output_for_texture_tex.png"))

Example #2
def main(_):
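    # Step 4: unwrap the mid/left/right views into UV space, blend them into a
    # single texture, then fit the 3DMM UV texture basis parameters (para_tex)
    # to the blended result.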
    print('---- step4 start -----')
    os.environ["CUDA_VISIBLE_DEVICES"] = FLAGS.GPU_NO
    print('running base:', FLAGS.output_dir)

    basis3dmm = load_3dmm_basis(FLAGS.basis3dmm_path,
                                FLAGS.uv_path,
                                is_whole_uv=True)

    # load data (all arrays are float32 in the [0, 255] range)
    img_list, pro_yx_list, preset_masks_list, base_uv, pro_xyz_list = load_from_npz(
        FLAGS.output_dir, basis3dmm)

    # set up TF constants for the base UV and the preset masks
    base_uv_batch = tf.constant(base_uv[np.newaxis, ...], name='base_uv')
    mask_mid_batch = tf.constant(
        preset_masks_list[0][np.newaxis, ..., np.newaxis], tf.float32)
    mask_left_batch = tf.constant(
        preset_masks_list[1][np.newaxis, ..., np.newaxis], tf.float32)
    mask_right_batch = tf.constant(
        preset_masks_list[2][np.newaxis, ..., np.newaxis], tf.float32)
    mask_batch = tf.clip_by_value(
        mask_mid_batch + mask_left_batch + mask_right_batch, 0, 1)

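    # All three input views must share the same image size; one placeholder is
    # built per view.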
    imageH = img_list[0].shape[0]
    imageW = img_list[0].shape[1]
    assert (img_list[0].shape[:2] == img_list[1].shape[:2]
            and img_list[0].shape[:2] == img_list[2].shape[:2])
    image_mid_batch = tf.placeholder(dtype=tf.float32,
                                     shape=[1, imageH, imageW, 3],
                                     name='image_mid')
    image_left_batch = tf.placeholder(dtype=tf.float32,
                                      shape=[1, imageH, imageW, 3],
                                      name='image_left')
    image_right_batch = tf.placeholder(dtype=tf.float32,
                                       shape=[1, imageH, imageW, 3],
                                       name='image_right')

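    # Placeholders for the projected 3D vertex positions of each view; NV is the
    # number of vertices in the 3DMM shape basis.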
    NV = basis3dmm['basis_shape'].shape[1] // 3
    proj_xyz_mid_batch = tf.placeholder(dtype=tf.float32,
                                        shape=[1, NV, 3],
                                        name='proj_xyz_mid')
    proj_xyz_left_batch = tf.placeholder(dtype=tf.float32,
                                         shape=[1, NV, 3],
                                         name='proj_xyz_left')
    proj_xyz_right_batch = tf.placeholder(dtype=tf.float32,
                                          shape=[1, NV, 3],
                                          name='proj_xyz_right')

    ver_normals_mid_batch, _ = Projector.get_ver_norm(proj_xyz_mid_batch,
                                                      basis3dmm['tri'],
                                                      'normal_mid')
    ver_normals_left_batch, _ = Projector.get_ver_norm(proj_xyz_left_batch,
                                                       basis3dmm['tri'],
                                                       'normal_left')
    ver_normals_right_batch, _ = Projector.get_ver_norm(
        proj_xyz_right_batch, basis3dmm['tri'], 'normal_right')

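    # Unwrap each input view into a 512x512 UV texture; the per-vertex normals
    # are passed to unwrap_utils.unwrap_img_into_uv (defined elsewhere in the
    # repo), presumably for visibility handling.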
    uv_mid_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_mid_batch / 255.0, proj_xyz_mid_batch, ver_normals_mid_batch,
        basis3dmm, 512)

    uv_left_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_left_batch / 255.0, proj_xyz_left_batch, ver_normals_left_batch,
        basis3dmm, 512)

    uv_right_batch, _ = unwrap_utils.unwrap_img_into_uv(
        image_right_batch / 255.0, proj_xyz_right_batch, ver_normals_right_batch,
        basis3dmm, 512)

    uv_left_batch.set_shape((1, 512, 512, 3))
    uv_mid_batch.set_shape((1, 512, 512, 3))
    uv_right_batch.set_shape((1, 512, 512, 3))

    # Laplacian pyramid blending: merge the left and right views, then the mid
    # view, and finally blend the result onto the base UV with color matching.

    cur_uv = tf_blend_uv(uv_left_batch,
                         uv_right_batch,
                         mask_right_batch,
                         match_color=False)
    cur_uv = tf_blend_uv(cur_uv,
                         uv_mid_batch,
                         mask_mid_batch,
                         match_color=False)
    uv_batch = tf_blend_uv(base_uv_batch / 255,
                           cur_uv,
                           mask_batch,
                           match_color=True)
    uv_batch = uv_batch * 255

    uv_batch = tf.identity(uv_batch, name='uv_tex')

    print("uv_batch: ", uv_batch.shape)

    #------------------------------------------------------------------------------------------
    # build fitting graph
    uv_bases = basis3dmm['uv']
    para_tex = tf.get_variable(shape=[1, uv_bases['basis'].shape[0]],
                               initializer=tf.zeros_initializer(),
                               name='para_tex')

    uv_rgb, uv_mask = get_region_uv_texture(uv_bases, para_tex)
    print("uv_rgb: ", uv_rgb.shape)

    # build fitting loss
    input_uv512_batch = tf.placeholder(dtype=tf.float32,
                                       shape=[1, 512, 512, 3],
                                       name='gt_uv')
    tot_loss = 0.
    loss_str = 'total:{}'
    if FLAGS.photo_weight > 0:
        photo_loss = Losses.photo_loss(uv_rgb / 255.0,
                                       input_uv512_batch / 255.0, uv_mask)
        tot_loss = tot_loss + photo_loss * FLAGS.photo_weight
        loss_str = loss_str + '; photo:{}'

    if FLAGS.uv_tv_weight > 0:
        uv_tv_loss = Losses.uv_tv_loss(uv_rgb / 255.0, uv_mask)
        tot_loss = tot_loss + uv_tv_loss * FLAGS.uv_tv_weight
        loss_str = loss_str + '; tv:{}'

    if FLAGS.uv_reg_tex_weight > 0:
        uv_reg_tex_loss = Losses.reg_loss(para_tex)
        tot_loss = tot_loss + uv_reg_tex_loss * FLAGS.uv_reg_tex_weight
        loss_str = loss_str + '; reg:{}'
    optim = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = optim.minimize(tot_loss)

    with tf.Session() as sess:

        if FLAGS.write_graph:
            tf.train.write_graph(sess.graph_def,
                                 '',
                                 FLAGS.pb_path,
                                 as_text=True)
            exit()
        sess.run(tf.global_variables_initializer())

        start_time = time.time()
        uv_extract, o_uv_left_batch, o_uv_mid_batch, o_uv_right_batch, o_mask_right_batch = \
            sess.run(
                [uv_batch, uv_left_batch, uv_mid_batch, uv_right_batch,
                 mask_right_batch],
                {
                    image_mid_batch: img_list[0][np.newaxis, ...],
                    image_left_batch: img_list[1][np.newaxis, ...],
                    image_right_batch: img_list[2][np.newaxis, ...],
                    proj_xyz_mid_batch: pro_xyz_list[0][np.newaxis, ...],
                    proj_xyz_left_batch: pro_xyz_list[1][np.newaxis, ...],
                    proj_xyz_right_batch: pro_xyz_list[2][np.newaxis, ...],
                })
        # keep an unblurred copy of the merged texture for saving later
        uv_extract_dump = np.copy(uv_extract)
        print('  -------------  time wrap and merge:',
              time.time() - start_time)

        # blurring the merged texture helps the parameter fit below
        uv_extract = np.asarray(uv_extract[0], np.float32)
        kernel = np.ones((FLAGS.blur_kernel, FLAGS.blur_kernel), np.float32) / (
            FLAGS.blur_kernel * FLAGS.blur_kernel)
        for _ in range(FLAGS.blur_times):
            uv_extract = cv2.filter2D(uv_extract, -1, kernel)
        if FLAGS.is_downsize:
            uv_extract = cv2.resize(uv_extract, (256, 256))
        uv_extract = cv2.resize(uv_extract, (512, 512))
        uv_extract = uv_extract[np.newaxis, ...]

        # iteratively fit the texture parameters
        start_time = time.time()
        for i in range(FLAGS.train_step):
            l1, l2, l3, l4, _ = sess.run(
                [tot_loss, photo_loss, uv_tv_loss, uv_reg_tex_loss, train_op],
                {input_uv512_batch: uv_extract})
            if i % 50 == 0:
                print(i, loss_str.format(l1, l2, l3, l4))
        para_tex_out = sess.run(para_tex, {input_uv512_batch: uv_extract})
        print(' ------------- time fit uv paras:', time.time() - start_time)

    uv_fit, face_mask = construct(uv_bases, para_tex_out, 512)
    uv_fit = blend_uv(base_uv / 255, uv_fit / 255, face_mask, True, 5)
    uv_fit = np.clip(uv_fit * 255, 0, 255)

    # save
    prefix = "bapi"
    Image.fromarray(np.squeeze(uv_extract_dump).astype(np.uint8)).save(
        os.path.join(FLAGS.output_dir, prefix + '_tex_merge.png'))
    Image.fromarray(np.squeeze(uv_fit).astype(np.uint8)).save(
        os.path.join(FLAGS.output_dir, prefix + '_tex_fit.png'))
    Image.fromarray(np.squeeze(o_uv_mid_batch).astype(np.uint8)).save(
        os.path.join(FLAGS.output_dir, prefix + '_uv_mid_batch.png'))
    Image.fromarray(np.squeeze(o_uv_left_batch).astype(np.uint8)).save(
        os.path.join(FLAGS.output_dir, prefix + '_uv_left_batch.png'))
    Image.fromarray(np.squeeze(o_uv_right_batch).astype(np.uint8)).save(
        os.path.join(FLAGS.output_dir, prefix + '_uv_right_batch.png'))
    Image.fromarray(np.squeeze(o_mask_right_batch).astype(np.uint8)).save(
        os.path.join(FLAGS.output_dir, prefix + '_mask_right_batch.png'))

    np.save(os.path.join(FLAGS.output_dir, "para_tex_init.npy"), para_tex_out)
    print('---- step4 succeed -----')

Example #3
def main(_):
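    # Fit per-region 3DMM UV texture bases to each unwrapped texture in
    # FLAGS.data_dir, then blend the fitted face region into the base texture
    # and base normal maps.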

    # save parameters
    save_params()

    mask_batch = tf.placeholder(
        dtype=tf.float32, shape=[1, 512, 512, 1], name="uv_mask"
    )
    tex_batch = tf.placeholder(dtype=tf.float32, shape=[1, 512, 512, 3], name="uv_tex")

    var_mask_batch = tf.get_variable(
        shape=[1, 512, 512, 1], dtype=tf.float32, name="var_mask", trainable=False
    )
    var_tex_batch = tf.get_variable(
        shape=[1, 512, 512, 3], dtype=tf.float32, name="var_tex", trainable=False
    )

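    # Copy the fed placeholders into non-trainable variables once per texture so
    # the optimization loop below can run without a feed_dict.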
    assign_op = tf.group(
        [tf.assign(var_mask_batch, mask_batch), tf.assign(var_tex_batch, tex_batch)],
        name="assign_op",
    )

    # load the 3DMM basis with per-region UV bases
    basis3dmm = load_3dmm_basis(
        FLAGS.basis3dmm_path,
        FLAGS.uv_path,
        uv_weight_mask_path=FLAGS.uv_weight_mask_path,
        is_train=False,
        is_whole_uv=False,
    )

    # build fitting graph
    uv_region_bases = basis3dmm["uv"]
    para_uv_dict = {}
    for region_name in uv_region_bases:
        region_basis = uv_region_bases[region_name]
        para = tf.get_variable(
            shape=[1, region_basis["basis"].shape[0]],
            initializer=tf.zeros_initializer(),
            name="para_" + region_name,
        )
        para_uv_dict[region_name] = para

    uv_rgb, uv_mask = get_uv_texture(uv_region_bases, para_uv_dict)
    photo_weight_mask = get_weighted_photo_mask(uv_region_bases)

    # build fitting loss
    tot_loss = 0.0
    loss_str = ""

    if FLAGS.photo_weight > 0:
        photo_loss = Losses.ws_photo_loss(
            var_tex_batch, uv_rgb / 255.0, uv_mask * photo_weight_mask * var_mask_batch
        )
        photo_loss = tf.identity(photo_loss, name="photo_loss")
        tot_loss = tot_loss + photo_loss * FLAGS.photo_weight
        loss_str = "photo:{}"

    if FLAGS.uv_tv_weight > 0:
        uv_tv_loss = Losses.uv_tv_loss2(
            uv_rgb / 255, uv_mask, basis3dmm["uv_weight_mask"]
        )
        uv_tv_loss = tf.identity(uv_tv_loss, name="uv_tv_loss")
        tot_loss = tot_loss + uv_tv_loss * FLAGS.uv_tv_weight
        loss_str = loss_str + ";tv:{}"

    if FLAGS.uv_reg_tex_weight > 0:
        uv_reg_tex_loss = 0.0
        for key in para_uv_dict:
            para = para_uv_dict[key]
            reg_region_loss = Losses.reg_loss(para)
            uv_reg_tex_loss = uv_reg_tex_loss + reg_region_loss
        uv_reg_tex_loss = uv_reg_tex_loss / len(para_uv_dict)
        uv_reg_tex_loss = tf.identity(uv_reg_tex_loss, name="uv_reg_tex_loss")
        tot_loss = tot_loss + uv_reg_tex_loss * FLAGS.uv_reg_tex_weight
        loss_str = loss_str + ";reg:{}"

    optim = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    train_op = optim.minimize(tot_loss, name="train_op")
    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:

        if FLAGS.write_graph:
            tf.train.write_graph(sess.graph_def, "", FLAGS.pb_path, as_text=True)
            exit()

        uv_paths = sorted(glob.glob(os.path.join(FLAGS.data_dir, "*tex.png")))
        mask_paths = sorted(glob.glob(os.path.join(FLAGS.data_dir, "*mask.png")))
        # mask_paths = ['../resources/mask_front_face.png'] * len(uv_paths)
        # uv_paths = sorted(glob.glob(os.path.join(FLAGS.data_dir,'*delight_512.png')))
        # mask_paths = sorted(glob.glob(os.path.join(FLAGS.data_dir, '*mask_512.png')))

        # base uv
        base_uv = Image.open("../resources/base_tex.png")
        base_uv = np.asarray(base_uv, np.float32) / 255.0
        base_normal = Image.open("../resources/base_normal.png")
        base_normal = np.asarray(base_normal, np.float32) / 255.0

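        # Fit texture parameters independently for each (UV texture, mask) pair
        # found in FLAGS.data_dir.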
        for uv_path, mask_path in zip(uv_paths, mask_paths):
            uv_input = np.asarray(Image.open(uv_path)).astype(np.float32) / 255.0
            mask_input = np.asarray(Image.open(mask_path)).astype(np.float32) / 255.0
            if mask_input.shape[0] != 512:
                mask_input = cv2.resize(mask_input, (512, 512))
            if uv_input.shape[0] != 512:
                uv_input = cv2.resize(uv_input, (512, 512))

            if len(mask_input.shape) != 3:
                mask_input = mask_input[..., np.newaxis]
            mask_input = mask_input[:, :, 0:1]

            uv_input = uv_input[np.newaxis, ...]
            mask_input = mask_input[np.newaxis, ...]

            sess.run(init_op)
            sess.run(assign_op, {tex_batch: uv_input, mask_batch: mask_input})

            for i in range(FLAGS.train_step):
                # l1, l2, l3, l4, l5, _ = sess.run([tot_loss, photo_loss, uv_tv_loss, uv_reg_tex_loss, uv_consistency_loss, train_op])
                l1, l2, l3, l4, _ = sess.run(
                    [tot_loss, photo_loss, uv_tv_loss, uv_reg_tex_loss, train_op]
                )
                if i == 1:
                    start_time = time.time()
                if i % 100 == 0:
                    print(i, loss_str.format(l1, l2, l3, l4))

            para_out_dict = sess.run(para_uv_dict)

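            # Reconstruct the full-resolution (2K) texture and normal maps from
            # the fitted per-region parameters.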
            face_uv, face_mask = np_get_uv_texture(basis3dmm["uv2k"], para_out_dict)
            face_normal, face_mask = np_get_uv_texture(
                basis3dmm["normal2k"], para_out_dict
            )

            face_uv = np.clip(face_uv, 0, 255)
            face_normal = np.clip(face_normal, 0, 255)
            face_mask = np.clip(face_mask, 0, 1)

            prefix = uv_path.split("/")[-1].split(".")[0]
            print(prefix)

            # Image.fromarray(face_uv.astype(np.uint8)).save(
            #    os.path.join(FLAGS.out_dir, prefix + '_face_uv.png'))

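            # Blend the fitted face region into the base texture (with color
            # matching) and into the base normal map (without color matching).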
            out_uv = blend_uv(base_uv, face_uv / 255, face_mask, True)
            out_normal = blend_uv(base_normal, face_normal / 255, face_mask, False)
            out_uv = np.clip(out_uv * 255, 0, 255)
            out_normal = np.clip(out_normal * 255, 0, 255)

            out_uv = Image.fromarray(out_uv.astype(np.uint8))
            out_normal = Image.fromarray(out_normal.astype(np.uint8))
            out_uv.save(os.path.join(FLAGS.out_dir, prefix + "_D.png"))
            out_normal.save(os.path.join(FLAGS.out_dir, prefix + "_N.png"))