import torch
from torch import nn
import math
import scipy.sparse
import numpy as np
from torch.nn.parameter import Parameter
from device import device
from psbody.mesh import Mesh
from graphlib import graph, coarsening, utils, mesh_sampling
import pickle

reference_mesh_file = "./data/template.obj"
ds_factors = [4, 4, 4, 4]
# Generate adjacency matrices A, downsampling matrices D, and upsampling matrices U
# by sampling the mesh 4 times; each time the mesh is downsampled by a factor of 4.
reference_mesh = Mesh(filename=reference_mesh_file)
M, A, D, U = mesh_sampling.generate_transform_matrices(reference_mesh,
                                                       ds_factors)
pickle.dump([M, A, D, U], open("./data/pai_template.pkl", 'wb'))
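
#%%
# Minimal sketch (not part of the original script): load the pickled sampling
# transforms back and inspect the mesh hierarchy produced above.
with open("./data/pai_template.pkl", 'rb') as f:
    M, A, D, U = pickle.load(f)
print([m.v.shape[0] for m in M])  # vertex count at each resolution level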

#%%
from psbody.mesh import Mesh
import torch

reference_mesh_file = "./data/template.obj"
reference_mesh = Mesh(filename=reference_mesh_file)
mean = torch.load('./data/Processed/sliced/mean.tch')
reference_mesh.v = mean.numpy()
reference_mesh.show()

#%%
Example #2
def vec2mesh(self, vec):
    vec = vec.reshape((self.n_vertex, 3)) * self.std + self.mean
    return Mesh(v=vec, f=self.reference_mesh.f)
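
# Hedged sketch of the inverse mapping (not in the original listing); it assumes
# the same self.mean, self.std and self.n_vertex attributes as vec2mesh above.
def mesh2vec(self, mesh):
    return ((mesh.v - self.mean) / self.std).reshape(self.n_vertex * 3)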
Example #3




import argparse
import os

import numpy as np
from psbody.mesh import Mesh
# step1_make_unwraps, step2_segm_vote_gmm, step3_stitch_texture and loop_subdivider
# are assumed to come from the surrounding project.

parser = argparse.ArgumentParser()
parser.add_argument('dir', type=str)
args = parser.parse_args()
name = ' '.join(args.dir.split('/')).split()[-1]  # last non-empty component of args.dir
os.makedirs(os.path.join(args.dir, 'unwraps'), exist_ok=True)
step1_make_unwraps.main(os.path.join(args.dir, 'frame_data.pkl'),
                        os.path.join(args.dir, 'frames'),
                        os.path.join(args.dir, 'segmentations'),
                        os.path.join(args.dir, 'unwraps'))
step2_segm_vote_gmm.main(os.path.join(args.dir, 'unwraps'),
                         os.path.join(args.dir, 'segm.png'),
                         os.path.join(args.dir, 'gmm.pkl'))
step3_stitch_texture.main(os.path.join(args.dir, 'unwraps'),
                          os.path.join(args.dir, 'segm.png'),
                          os.path.join(args.dir, 'gmm.pkl'),
                          os.path.join(args.dir, name + '_octopus.jpg'), 20)

filename = os.path.join(args.dir,name+'.obj')
body = Mesh(filename=filename.replace(name+'.obj',name+'_octopus.obj'))
body_tex = filename.replace('.obj', '_octopus.jpg')
if not os.path.exists(body_tex):
    body_tex = 'tex_{}'.format(body_tex)

v, f = body.v, body.f
(mapping, hf) = loop_subdivider(v, f)
hv = mapping.dot(v.ravel()).reshape(-1, 3)
body_hres = Mesh(hv, hf)

vt, ft = np.hstack((body.vt, np.ones((body.vt.shape[0], 1)))), body.ft
(mappingt, hft) = loop_subdivider(vt, ft)
hvt = mappingt.dot(vt.ravel()).reshape(-1, 3)[:, :2]
body_hres.vt, body_hres.ft = hvt, hft

body_hres.set_texture_image(body_tex)
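
# Hedged follow-up (not in the original script): write the subdivided, textured
# mesh back to disk; the output filename is an assumption.
body_hres.write_obj(os.path.join(args.dir, name + '_octopus_hres.obj'))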
Example #4
def run_2d_lmk_fitting(tf_model_fname, template_fname, flame_lmk_path,
                       texture_mapping, target_img_path, target_lmk_path,
                       out_path):
    if 'generic' not in tf_model_fname:
        print(
            'You are fitting a gender specific model (i.e. female / male). Please make sure you selected the right gender model. Choose the generic model if gender is unknown.'
        )
    if not os.path.exists(template_fname):
        print('Template mesh (in FLAME topology) not found - %s' %
              template_fname)
        return
    if not os.path.exists(flame_lmk_path):
        print('FLAME landmark embedding not found - %s ' % flame_lmk_path)
        return
    if not os.path.exists(target_img_path):
        print('Target image not found - %s' % target_img_path)
        return
    if not os.path.exists(target_lmk_path):
        print('Landmarks of target image not found - %s' % target_lmk_path)
        return

    if not os.path.exists(out_path):
        os.makedirs(out_path)

    lmk_face_idx, lmk_b_coords = load_embedding(flame_lmk_path)

    target_img = cv2.imread(target_img_path)
    lmk_2d = np.load(target_lmk_path)

    weights = {}
    # Weight of the landmark distance term
    weights['lmk'] = 1.0
    # Weight of the shape regularizer
    weights['shape'] = 1e-3
    # Weight of the expression regularizer
    weights['expr'] = 1e-3
    # Weight of the neck pose (i.e. rotation around the neck) regularizer
    weights['neck_pose'] = 100.0
    # Weight of the jaw pose (i.e. jaw rotation for opening the mouth) regularizer
    weights['jaw_pose'] = 1e-3
    # Weight of the eyeball pose (i.e. eyeball rotations) regularizer
    weights['eyeballs_pose'] = 10.0

    result_mesh, result_scale = fit_lmk2d(target_img, lmk_2d, template_fname,
                                          tf_model_fname, lmk_face_idx,
                                          lmk_b_coords, weights)

    if sys.version_info >= (3, 0):
        texture_data = np.load(texture_mapping,
                               allow_pickle=True,
                               encoding='latin1').item()
    else:
        texture_data = np.load(texture_mapping, allow_pickle=True).item()
    texture_map = compute_texture_map(target_img, result_mesh, result_scale,
                                      texture_data)

    out_mesh_fname = os.path.join(
        out_path,
        os.path.splitext(os.path.basename(target_img_path))[0] + '.obj')
    out_img_fname = os.path.join(
        out_path,
        os.path.splitext(os.path.basename(target_img_path))[0] + '.png')

    cv2.imwrite(out_img_fname, texture_map)
    result_mesh.set_vertex_colors('white')
    result_mesh.vt = texture_data['vt']
    result_mesh.ft = texture_data['ft']
    result_mesh.set_texture_image(out_img_fname)
    result_mesh.write_obj(out_mesh_fname)
    np.save(
        os.path.join(
            out_path,
            os.path.splitext(os.path.basename(target_img_path))[0] +
            '_scale.npy'), result_scale)

    mv = MeshViewers(shape=[1, 2], keepalive=True)
    mv[0][0].set_static_meshes([Mesh(result_mesh.v, result_mesh.f)])
    mv[0][1].set_static_meshes([result_mesh])
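
# Hypothetical invocation of run_2d_lmk_fitting; all paths below are
# placeholders, not taken from the original source:
# run_2d_lmk_fitting('./models/generic_model', './data/template.obj',
#                    './data/flame_landmark_embedding.pkl', './data/texture_data.npy',
#                    './input/face.jpg', './input/face_lmks.npy', './results')
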
def render_mesh_helper(mesh, t_center, rot=np.zeros(3), v_colors=None, errors=None, error_unit='m', min_dist_in_mm=0.0, max_dist_in_mm=3.0, z_offset=0):
    camera_params = {'c': np.array([400, 400]),
                     'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
                     'f': np.array([4754.97941935 / 2, 4754.97941935 / 2])}

    frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}

    mesh_copy = Mesh(mesh.v, mesh.f)
    # mesh_copy.v[:] = np.matmul(mesh_copy.v[:], rot)
    # mesh_copy.v[:] = cv2.Rodrigues(rot)[0].dot((mesh_copy.v-t_center).T).T+t_center

    if errors is not None:
        intensity = 0.5
        unit_factor = get_unit_factor('mm')/get_unit_factor(error_unit)
        errors = unit_factor*errors

        norm = mpl.colors.Normalize(vmin=min_dist_in_mm, vmax=max_dist_in_mm)
        cmap = cm.get_cmap(name='jet')
        colormapper = cm.ScalarMappable(norm=norm, cmap=cmap)
        rgba_per_v = colormapper.to_rgba(errors)
        rgb_per_v = rgba_per_v[:, 0:3]
    elif v_colors is not None:
        intensity = 0.5
        rgb_per_v = v_colors
    else:
        intensity = 1.5
        rgb_per_v = None

    tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v, faces=mesh_copy.f, vertex_colors=rgb_per_v)
    render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, smooth=True)

    scene = pyrender.Scene(ambient_light=[.2, .2, .2], bg_color=[255, 255, 255])
    camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
                                      fy=camera_params['f'][1],
                                      cx=camera_params['c'][0],
                                      cy=camera_params['c'][1],
                                      znear=frustum['near'],
                                      zfar=frustum['far'])

    scene.add(render_mesh, pose=np.eye(4))

    camera_pose = np.eye(4)
    camera_pose[:3,3] = np.array([0, 0, 1.0-z_offset])
    scene.add(camera, pose=camera_pose)

    angle = np.pi / 6.0
    pos = camera_pose[:3,3]
    light_color = np.array([1., 1., 1.])
    light = pyrender.PointLight(color=light_color, intensity=intensity)

    light_pose = np.eye(4)
    light_pose[:3,3] = pos
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] =  cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3,3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    flags = pyrender.RenderFlags.SKIP_CULL_FACES
    r = pyrender.OffscreenRenderer(viewport_width=frustum['width'], viewport_height=frustum['height'])
    color, _ = r.render(scene, flags=flags)

    return color[..., ::-1]
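
# Minimal usage sketch (not from the original file); assumes cv2, numpy and Mesh
# are imported as in the snippets above and that the mesh path exists:
# mesh = Mesh(filename='./data/template.obj')
# img_bgr = render_mesh_helper(mesh, t_center=mesh.v.mean(axis=0))
# cv2.imwrite('render.png', img_bgr)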
Example #6
def fit_sources(
    dir_tup_list,
    tf_model_fname,
    template_fname,
    weight_reg_shape,
    weight_reg_expr,
    weight_reg_neck_pos,
    weight_reg_jaw_pos,
    weight_reg_eye_pos,
    showing=False
):
    global g_mv
    if showing:
        g_mv = MeshViewer()

    saver = tf.train.import_meta_graph(tf_model_fname + '.meta')

    graph = tf.get_default_graph()
    tf_model = graph.get_tensor_by_name(u'vertices:0')

    with tf.Session() as session:
        saver.restore(session, tf_model_fname)

        template = Mesh(filename=template_fname)
        tf_src = tf.Variable(tf.zeros(template.v.shape, dtype=tf.float64))

        # get all params
        tf_trans = [x for x in tf.trainable_variables() if 'trans' in x.name][0]
        tf_rot   = [x for x in tf.trainable_variables() if 'rot'   in x.name][0]
        tf_pose  = [x for x in tf.trainable_variables() if 'pose'  in x.name][0]
        tf_shape = [x for x in tf.trainable_variables() if 'shape' in x.name][0]
        tf_exp   = [x for x in tf.trainable_variables() if 'exp'   in x.name][0]

        def _save_state(*names, **kwargs):
            state = dict()
            if "trans" in names: state["trans"] = tf_trans.eval()
            if "rot"   in names: state["rot"]   = tf_rot.eval()
            if "pose"  in names: state["pose"]  = tf_pose.eval()
            if "shape" in names: state["shape"] = tf_shape.eval()
            if "exp"   in names: state["exp"]   = tf_exp.eval()
            if kwargs.get("set_zero", False):
                _zero_state(*names)
            return state

        def _load_state(state):
            ops = []
            if "trans" in state: ops.append(tf_trans.assign(state["trans"]))
            if "rot"   in state: ops.append(tf_rot.assign  (state["rot"]  ))
            if "pose"  in state: ops.append(tf_pose.assign (state["pose"] ))
            if "shape" in state: ops.append(tf_shape.assign(state["shape"]))
            if "exp"   in state: ops.append(tf_exp.assign  (state["exp"]  ))
            session.run(ops)

        def _zero_state(*names):
            ops = []
            if "trans" in names: ops.append(tf_trans.assign(tf.zeros_like(tf_trans)))
            if "rot"   in names: ops.append(tf_rot  .assign(tf.zeros_like(tf_rot  )))
            if "pose"  in names: ops.append(tf_pose .assign(tf.zeros_like(tf_pose )))
            if "shape" in names: ops.append(tf_shape.assign(tf.zeros_like(tf_shape)))
            if "exp"   in names: ops.append(tf_exp  .assign(tf.zeros_like(tf_exp  )))
            session.run(ops)

        mesh_dist     = tf.reduce_sum(tf.square(tf.subtract(tf_model, tf_src)))
        neck_pose_reg = tf.reduce_sum(tf.square(tf_pose[:3]))
        jaw_pose_reg  = tf.reduce_sum(tf.square(tf_pose[3:6]))
        eye_pose_reg  = tf.reduce_sum(tf.square(tf_pose[6:]))
        shape_reg     = tf.reduce_sum(tf.square(tf_shape))
        exp_reg       = tf.reduce_sum(tf.square(tf_exp))
        reg_term = (
            weight_reg_neck_pos * neck_pose_reg +
            weight_reg_jaw_pos  * jaw_pose_reg  +
            weight_reg_eye_pos  * eye_pose_reg  +
            weight_reg_shape    * shape_reg     +
            weight_reg_expr     * exp_reg
        )

        # optimizers
        optim_shared_rigid = scipy_pt(
            loss=mesh_dist,
            var_list=[tf_trans, tf_rot],
            method='L-BFGS-B',
            options={'disp': 0}
        )
        optim_shared_all = scipy_pt(
            loss=mesh_dist+reg_term,
            var_list=[tf_trans, tf_rot, tf_pose, tf_shape, tf_exp],
            method='L-BFGS-B',
            options={'disp': 0}
        )
        optim_seq = scipy_pt(
            loss=mesh_dist+reg_term,
            var_list=[tf_shape, tf_exp],
            method='L-BFGS-B', options={'disp': 0, 'maxiter': 50}
        )

        def _fit_sentence(src_dir, dst_dir, prm_dir, last_speaker):
            _anchor = os.path.join(dst_dir, "_anchor")
            if os.path.exists(_anchor):
                print("- Skip " + src_dir)
                return
            if not os.path.exists(src_dir):
                print("- Failed to find " + src_dir)
                return
            if not os.path.exists(dst_dir): os.makedirs(dst_dir)
            if not os.path.exists(prm_dir): os.makedirs(prm_dir)

            ply_files = []
            for root, _, files in os.walk(src_dir):
                for f in files:
                    if os.path.splitext(f)[1] == ".ply":
                        ply_files.append(os.path.join(root, f))
            ply_files = sorted(ply_files)

            # get shared
            src_mesh = Mesh(filename=ply_files[0])
            session.run(tf.assign(tf_src, src_mesh.v))

            speaker = os.path.basename(os.path.dirname(src_dir))

            if last_speaker != speaker:
                print("- clear speaker information")
                _zero_state("trans", "rot", "pose", "shape", "exp")
            else:
                _zero_state("exp")

            stt_dir = os.path.join(os.path.dirname(dst_dir), "state")
            if os.path.exists(stt_dir):
                state_dict = dict(
                    trans  = np.load(os.path.join(stt_dir, "trans.npy")),
                    rot    = np.load(os.path.join(stt_dir, "rot.npy")),
                    pose   = np.load(os.path.join(stt_dir, "pose.npy")),
                    shape  = np.load(os.path.join(stt_dir, "shape.npy")),
                )
                _load_state(state_dict)

                fitting_mesh = Mesh(session.run(tf_model), src_mesh.f)
                fitting_mesh.write_ply(os.path.join(stt_dir, "zero.ply"))
            fit_zero_dir = os.path.join(os.path.dirname(os.path.dirname(dst_dir)), "zero_exp")
            if not os.path.exists(fit_zero_dir): os.makedirs(fit_zero_dir)

            print("- " + speaker + " " + os.path.basename(src_dir))
            print("  -> fit shared parameters...")
            optim_shared_rigid.minimize(session)
            optim_shared_all.minimize(session)

            state_dict = _save_state("exp", set_zero=True)

            fitting_mesh = Mesh(session.run(tf_model), src_mesh.f)
            fitting_mesh.write_ply(os.path.join(fit_zero_dir, "{}.ply".format(speaker)))

            _load_state(state_dict)

            return

            if not os.path.exists(stt_dir): os.makedirs(stt_dir)
            np.save(os.path.join(stt_dir, "trans.npy"), tf_trans.eval(), allow_pickle=False)
            np.save(os.path.join(stt_dir, "rot.npy"),   tf_rot.eval(),   allow_pickle=False)
            np.save(os.path.join(stt_dir, "pose.npy"),  tf_pose.eval(),  allow_pickle=False)
            np.save(os.path.join(stt_dir, "shape.npy"), tf_shape.eval(), allow_pickle=False)

            progress = tqdm(ply_files)
            for src_fname in progress:
                frame = os.path.basename(src_fname)
                progress.set_description("  -> " + frame)
                dst_fname = os.path.join(dst_dir, frame)
                # param filename
                prm_fname = os.path.join(prm_dir, frame)
                exp_fname = os.path.splitext(prm_fname)[0] + '_exp.npy'
                idn_fname = os.path.splitext(prm_fname)[0] + '_idn.npy'

                src_mesh = Mesh(filename=src_fname)
                session.run(tf.assign(tf_src, src_mesh.v))

                optim_seq.minimize(session)

                # save expr
                np.save(exp_fname, tf_exp.eval())
                np.save(idn_fname, tf_shape.eval())

                # state_dict = _save_state("trans", "rot", "pose", "shape", set_zero=True)

                # save mesh
                fitting_mesh = Mesh(session.run(tf_model), src_mesh.f)
                fitting_mesh.write_ply(dst_fname)

                # _load_state(state_dict)
                # print(tf_shape.eval())

                if showing:
                    g_mv.set_static_meshes([fitting_mesh])

            os.system("touch {}".format(_anchor))
            return speaker

        last_speaker = None
        for (src, dst, prm) in dir_tup_list:
            last_speaker = _fit_sentence(src, dst, prm, last_speaker)
Example #7
def cage(length=1, vc=name_to_rgb['black']):

    cage_points = np.array([[-1., -1., -1.], [1., 1., 1.], [1., -1., 1.],
                            [-1., 1., -1.]])
    c = Mesh(v=length * cage_points, f=[], vc=vc)
    return c
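
# Hypothetical usage (not in the original snippet): a half-size point cage that
# can be shown alongside other meshes in a viewer, e.g.
# MeshViewer().set_static_meshes([some_mesh, cage(length=0.5)])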
Example #8
    col[(verts[:, 1] < -1.14) & (verts[:, 0] >= 0)] = 1

    if display:
        ms.set_vertex_colors_from_weights(col)
        ms.show()

    print('left_foot ', np.where(col)[0].shape)
    return col


if __name__ == "__main__":
    smplx = get_tpose_smplx()
    verts = smplx.vertices.detach().cpu().numpy().squeeze()
    faces = smplx.faces

    ms = Mesh(v=verts, f=faces)

    col = np.zeros((10475, ))
    display = False
    rfa = cut_right_forearm(display)
    col += (rfa * 0.1)

    rma = cut_right_midarm(display)
    col += (rma * 0.2)

    lfa = cut_left_forearm(display)
    col += (lfa * 0.3)

    lma = cut_left_midarm(display)
    col += (lma * 0.4)
Example #9
        avail_items = f.read().splitlines()
    avail_items = [k.split('\t') for k in avail_items]
    people_names = [k[0] for k in avail_items if k[1] == garment_class]
    shape_root = os.path.join(global_var.ROOT, 'neutral_shape_static_pose_new')
    smoothing = None

    shape_names = ["{:02d}".format(k) for k in range(0, 100)]

    for people_name, garment_class in tqdm(avail_items):
        shape_static_pose_people = os.path.join(shape_root, people_name)
        for shape_name in shape_names:
            garment_path = os.path.join(shape_static_pose_people, '{}_{}.obj'.format(shape_name, garment_class))
            if not os.path.exists(garment_path):
                print("{} doesn't exist".format(garment_path))
            try:
                m = Mesh(filename=garment_path)
            except AttributeError as e:
                print(e)
                print(garment_path)
                exit()

            if smoothing is None:
                smoothing = DiffusionSmoothing(m.v, m.f, Ltype="cotangent")
            steps = [30]

            verts_smooth = m.v.copy()
            for i, step in enumerate(steps):
                smooth_name = '_sm{}'.format(step)
                dst_path = os.path.join(shape_static_pose_people, '{}{}_{}.obj'.format(shape_name, smooth_name, garment_class))
                if os.path.exists(dst_path):
                    print("{} exists. Skip".format(dst_path))
    # Weighted identity: vertices listed in `indices` get weight `ww`, the rest weight 1.
    data = np.ones(nverts)
    data[indices] *= ww
    I = csr_matrix((data, (rc, rc)), shape=(nverts, nverts))

    # Least-squares system: keep L.dot(v) close to L.dot(mesh.v) (the differential
    # coordinates of the input mesh) while pulling the weighted vertices towards
    # tgt_points; solved via the normal equations.
    A = vstack([L, I])
    b = np.vstack((L.dot(mesh.v), tgt_points))

    res = spsolve(A.T.dot(A), A.T.dot(b))
    mres = Mesh(v=res, f=mesh.f)
    return mres


if __name__ == '__main__':
    import os
    ROOT = "/BS/cpatel/work/data/learn_anim/mixture_exp31/000_0/smooth_TShirtNoCoat/0990/"
    body = Mesh(filename=os.path.join(ROOT, "body_160.ply"))
    mesh = Mesh(filename=os.path.join(ROOT, "pred_160.ply"))

    mesh1 = remove_interpenetration_fast(mesh, body)
    mesh1.write_ply("/BS/cpatel/work/proccessed.ply")
    mesh.write_ply("/BS/cpatel/work/orig.ply")
    body.write_ply("/BS/cpatel/work/body.ply")

    # from psbody.mesh import MeshViewers
    # mvs = MeshViewers((1, 2))
    # mesh1.set_vertex_colors_from_weights(np.linalg.norm(mesh.v - mesh1.v, axis=1))
    # mesh.set_vertex_colors_from_weights(np.linalg.norm(mesh.v - mesh1.v, axis=1))
    # # mesh1.set_vertex_colors_from_weights(np.zeros(mesh.v.shape[0]))
    # # mesh.set_vertex_colors_from_weights(np.zeros(mesh.v.shape[0]))
    # mvs[0][0].set_static_meshes([mesh, body])
    # mvs[0][1].set_static_meshes([mesh1, body])
Example #11
def run_tailornet():
    gender = 'male'
    garment_class = 'pant'
    #garment_class = 'shirt'
    garment_combine = True
    garment_class_pairs = {
        #'pant': ['shirt', 't-shirt'],
        'pant': ['shirt'],
        'short-pant': ['shirt', 't-shirt'],
        'skirt': ['shirt', 't-shirt']
    }
    thetas, betas, gammas = get_single_frame_inputs(garment_class, gender)
    # # uncomment the line below to run inference on sequence data
    # thetas, betas, gammas = get_sequence_inputs(garment_class, gender)

    # load model
    tn_runner = get_tn_runner(gender=gender, garment_class=garment_class)
    # from trainer.base_trainer import get_best_runner
    # tn_runner = get_best_runner("/BS/cpatel/work/data/learn_anim/tn_baseline/{}_{}/".format(garment_class, gender))
    smpl = SMPL4Garment(gender=gender)

    # make output directory if it doesn't exist
    if not os.path.isdir(OUT_PATH):
        os.mkdir(OUT_PATH)

    # run inference
    for i, (theta, beta, gamma) in enumerate(zip(thetas, betas, gammas)):
        print(i, len(thetas))
        # normalize y-rotation to make it front facing
        theta_normalized = normalize_y_rotation(theta)
        with torch.no_grad():
            pred_verts_d = tn_runner.forward(
                thetas=torch.from_numpy(theta_normalized[None, :].astype(
                    np.float32)).cuda(),
                betas=torch.from_numpy(beta[None, :].astype(
                    np.float32)).cuda(),
                gammas=torch.from_numpy(gamma[None, :].astype(
                    np.float32)).cuda(),
            )[0].cpu().numpy()

        # get garment from predicted displacements
        body, pred_gar = smpl.run(beta=beta,
                                  theta=theta,
                                  garment_class=garment_class,
                                  garment_d=pred_verts_d)

        # gar_pair_hres = Mesh(
        #     filename=os.path.join(OUT_PATH,
        #                           "{}_gar_hres_{}_{:04d}.obj".format(gender, garment_class_pairs[garment_class][0], i)))
        # gar_pair_hres = remove_interpenetration_fast(gar_pair_hres, body)
        # pred_body = remove_interpenetration_fast_custom(body, gar_pair_hres, threshold=0.0010)
        # pred_body.write_obj(os.path.join(OUT_PATH, 'deformed_body_shirt_0001.obj'))
        #
        pred_gar = remove_interpenetration_fast(pred_gar, body)
        pred_body, _ = remove_interpenetration_fast_custom(body,
                                                           pred_gar,
                                                           threshold=0.0005)
        pred_body.write_obj(
            os.path.join(OUT_PATH, 'deformed_body_pant_shirt.obj'))

        hv, hf, mapping = get_hres(pred_gar.v, pred_gar.f)
        pred_gar_hres = Mesh(hv, hf)
        if garment_combine:
            for garment_class_pair in garment_class_pairs[garment_class]:
                gar_pair_hres = Mesh(filename=os.path.join(
                    OUT_PATH, "{}_gar_hres_{}_{:04d}.obj".format(
                        gender, garment_class_pair, i)))
                print(len(body.v), body.f.min(), body.f.max())
                print(len(pred_gar_hres.v), pred_gar_hres.f.min(),
                      pred_gar_hres.f.max())
                print(
                    np.vstack((body.v, pred_gar_hres.v)).shape,
                    np.min(pred_gar_hres.f + len(body.v)),
                    np.max(pred_gar_hres.f + len(body.v)))
                #gar_pair_hres = remove_interpenetration_fast(gar_pair_hres, pred_gar_hres)

                for _ in range(3):
                    gar_pair_hres = remove_interpenetration_fast(
                        gar_pair_hres, pred_body)
                for _ in range(3):
                    gar_pair_hres = remove_interpenetration_fast(
                        gar_pair_hres,
                        Mesh(
                            np.vstack((body.v, pred_gar_hres.v)),
                            np.vstack(
                                (body.f, pred_gar_hres.f + len(body.v)))))

                pred_gar_hres_inverse, _ = remove_interpenetration_fast_custom(
                    pred_gar_hres,
                    Mesh(np.vstack((body.v, gar_pair_hres.v)),
                         np.vstack((body.f, gar_pair_hres.f + len(body.v)))),
                    threshold=0.000005,
                    inverse=True)

                gar_pair_hres.write_obj(
                    os.path.join(
                        OUT_PATH,
                        "new2_{}_gar_hres_{}_inter_{}_{:04d}.obj".format(
                            gender, garment_class_pair, garment_class, i)))

        # save body and predicted garment
        body.write_obj(
            os.path.join(OUT_PATH, "{}_body_{:04d}.obj".format(gender, i)))
        pred_gar.write_obj(
            os.path.join(
                OUT_PATH,
                "{}_gar_{}_{:04d}.obj".format(gender, garment_class, i)))
        pred_gar_hres.write_obj(
            os.path.join(
                OUT_PATH,
                "{}_gar_hres_{}_{:04d}.obj".format(gender, garment_class, i)))
        pred_gar_hres_inverse.write_obj(
            os.path.join(
                OUT_PATH, "{}_gar_hres_{}_{:04d}_inverse.obj".format(
                    gender, garment_class, i)))
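

# Hedged entry point (an assumption, not copied from the original file):
if __name__ == '__main__':
    run_tailornet()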
Example #12
def vis_results(ho,
                dorig,
                coarse_net,
                refine_net,
                rh_model,
                save=False,
                save_dir=None,
                rh_model_pkl=None,
                vis=True):

    # with torch.no_grad():
    imw, imh = 1920, 780
    cols = len(dorig['bps_object'])
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')

    if vis:
        mvs = MeshViewers(window_width=imw,
                          window_height=imh,
                          shape=[1, cols],
                          keepalive=True)

    # drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
    #
    # for k in drec_cnet.keys():
    #     print('drec cnet', k, drec_cnet[k].shape)

    # verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    drec_cnet = {}

    hand_pose_in = torch.Tensor(ho.hand_pose[3:]).unsqueeze(0)
    mano_out_1 = rh_model_pkl(hand_pose=hand_pose_in)
    hand_pose_in = mano_out_1.hand_pose

    mTc = torch.Tensor(ho.hand_mTc)
    approx_global_orient = rotmat2aa(mTc[:3, :3].unsqueeze(0))

    if torch.isnan(approx_global_orient).any():  # Using honnotate?
        approx_global_orient = torch.Tensor(ho.hand_pose[:3]).unsqueeze(0)

    approx_global_orient = approx_global_orient.squeeze(1).squeeze(1)
    approx_trans = mTc[:3, 3].unsqueeze(0)

    target_verts = torch.Tensor(ho.hand_verts).unsqueeze(0)

    pose, trans, rot = util.opt_hand(rh_model, target_verts, hand_pose_in,
                                     approx_trans, approx_global_orient)

    # drec_cnet['hand_pose'] = torch.einsum('bi,ij->bj', [hand_pose_in, rh_model_pkl.hand_components])
    drec_cnet['transl'] = trans
    drec_cnet['global_orient'] = rot
    drec_cnet['hand_pose'] = pose

    verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    _, h2o, _ = point2point_signed(verts_rh_gen_cnet,
                                   dorig['verts_object'].to(device))

    drec_cnet['trans_rhand_f'] = drec_cnet['transl']
    drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
        drec_cnet['global_orient']).view(-1, 3, 3)
    drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(
        -1, 15, 3, 3)
    drec_cnet['verts_object'] = dorig['verts_object'].to(device)
    drec_cnet['h2o_dist'] = h2o.abs()

    print(
        'Hand fitting err',
        np.linalg.norm(
            verts_rh_gen_cnet.squeeze().detach().numpy() - ho.hand_verts, 2,
            1).mean())
    orig_obj = dorig['mesh_object'][0].v
    # print(orig_obj.shape, orig_obj)
    # print('Obj fitting err', np.linalg.norm(orig_obj - ho.obj_verts, 2, 1).mean())

    drec_rnet = refine_net(**drec_cnet)
    mano_out = rh_model(**drec_rnet)
    verts_rh_gen_rnet = mano_out.vertices
    joints_out = mano_out.joints

    if vis:
        for cId in range(0, len(dorig['bps_object'])):
            try:
                from copy import deepcopy
                meshes = deepcopy(dorig['mesh_object'])
                obj_mesh = meshes[cId]
            except Exception:
                obj_mesh = points_to_spheres(to_cpu(
                    dorig['verts_object'][cId]),
                                             radius=0.002,
                                             vc=name_to_rgb['green'])

            hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['pink'])
            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_cnet.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)
                # print('rotmat', rotmat)

            hand_mesh_gen_cnet.reset_face_normals()
            hand_mesh_gen_rnet.reset_face_normals()

            # mvs[0][cId].set_static_meshes([hand_mesh_gen_cnet] + obj_mesh, blocking=True)
            # mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet,obj_mesh], blocking=True)
            mvs[0][cId].set_static_meshes(
                [hand_mesh_gen_rnet, hand_mesh_gen_cnet, obj_mesh],
                blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path +
                                             '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh[0].write_ply(filename=save_path +
                                      '/obj_mesh_%d.ply' % cId)

    return verts_rh_gen_rnet, joints_out
def estimate_global_pose(landmarks,
                         key_vids,
                         model,
                         cam,
                         img,
                         fix_t=False,
                         viz=False,
                         SOLVE_FLATER=True):
    '''
    Estimates the global rotation and translation.
    The only difference from the single_frame_ferrari version is that all animals share the same keypoint order.
    '''
    # Redefining part names..
    part_names = [
        'leftEye', 'rightEye', 'chin', 'frontLeftFoot', 'frontRightFoot',
        'backLeftFoot', 'backRightFoot', 'tailStart', 'frontLeftKnee',
        'frontRightKnee', 'backLeftKnee', 'backRightKnee', 'leftShoulder',
        'rightShoulder', 'frontLeftAnkle', 'frontRightAnkle', 'backLeftAnkle',
        'backRightAnkle', 'neck', 'TailTip'
    ]

    # Use shoulder to "knee"(elbow) distance. also tail to "knee" if available.
    use_names = [
        'neck', 'leftShoulder', 'rightShoulder', 'backLeftKnee',
        'backRightKnee', 'tailStart', 'frontLeftKnee', 'frontRightKnee'
    ]
    use_ids = [part_names.index(name) for name in use_names]
    # These might not be visible
    visible = landmarks[:, 2].astype(bool)
    use_ids = [id for id in use_ids if visible[id]]
    if len(use_ids) < 3:
        print('Frontal?..')
        use_names += [
            'frontLeftAnkle', 'frontRightAnkle', 'backLeftAnkle',
            'backRightAnkle'
        ]
        model.pose[1] = np.pi / 2

    init_t = estimate_translation(landmarks, key_vids, cam.f[0].r, model)

    use_ids = [part_names.index(name) for name in use_names]
    use_ids = [id for id in use_ids if visible[id]]

    # Setup projection error:
    all_vids = np.hstack([key_vids[id] for id in use_ids])
    cam.v = model[all_vids]

    keypoints = landmarks[use_ids, :2].astype(float)

    # Duplicate keypoints for the # of vertices for that kp.
    num_verts_per_kp = [len(key_vids[row_id]) for row_id in use_ids]
    j2d = np.vstack([
        np.tile(kp, (num_rep, 1))
        for kp, num_rep in zip(keypoints, num_verts_per_kp)
    ])

    assert (cam.r.shape == j2d.shape)

    # SLOW but correct method!!
    # remember which ones belongs together,,
    group = np.hstack([
        index * np.ones(len(key_vids[row_id]))
        for index, row_id in enumerate(use_ids)
    ])
    assignments = np.vstack([group == i for i in np.arange(group[-1] + 1)])
    num_points = len(use_ids)
    proj_error = (ch.vstack([
        cam[choice] if np.sum(choice) == 1 else cam[choice].mean(axis=0)
        for choice in assignments
    ]) - keypoints) / np.sqrt(num_points)

    # Fast but not matching average:
    # Normalization weight
    j2d_norm_weights = np.sqrt(
        1. / len(use_ids) *
        np.vstack([1. / num * np.ones((num, 1)) for num in num_verts_per_kp]))
    proj_error_fast = j2d_norm_weights * (cam - j2d)

    if fix_t:
        obj = {'cam': proj_error_fast}
    else:
        obj = {
            'cam': proj_error_fast,
            'cam_t': 1e1 * (model.trans[2] - init_t[2])
        }

    # Only estimate body orientation
    if fix_t:
        free_variables = [model.pose[:3]]
    else:
        free_variables = [model.trans, model.pose[:3]]

    if not SOLVE_FLATER:
        obj['feq'] = 1e3 * (cam.f[0] - cam.f[1])
        # So it's under control
        obj['freg'] = 1e1 * (cam.f[0] - 3000) / 1000.
        # here without this cam.f goes negative.. (asking margin of 500)
        obj['fpos'] = ch.maximum(0, 500 - cam.f[0])
        # cam t also has to be positive!
        obj['cam_t_pos'] = ch.maximum(0, 0.01 - model.trans[2])
        del obj['cam_t']
        free_variables.append(cam.f)

    if viz:
        import matplotlib.pyplot as plt
        plt.ion()

        def on_step(_):
            plt.figure(1, figsize=(5, 5))
            plt.cla()
            plt.imshow(img[:, :, ::-1])
            img_here = render_mesh(Mesh(model.r, model.f), img.shape[1],
                                   img.shape[0], cam)
            plt.imshow(img_here)
            plt.scatter(j2d[:, 0], j2d[:, 1], c='w')
            plt.scatter(cam.r[:, 0], cam.r[:, 1])
            plt.draw()
            plt.pause(1e-3)
            if 'feq' in obj:
                print('flength %.1f %.1f, z %.f' %
                      (cam.f[0], cam.f[1], model.trans[2]))
    else:
        on_step = None

    from time import time
    t0 = time()
    init_angles = [[0, 0, 0]]  #, [1.5,0,0], [1.5,-1.,0]]
    scores = np.zeros(len(init_angles))
    for i, angle in enumerate(init_angles):
        # Init translation
        model.trans[:] = init_t
        model.pose[:3] = angle
        ch.minimize(obj,
                    x0=free_variables,
                    method='dogleg',
                    callback=on_step,
                    options={
                        'maxiter': 100,
                        'e_3': .0001
                    })
        scores[i] = np.sum(obj['cam'].r**2.)
    j = np.argmin(scores)
    model.trans[:] = init_t
    model.pose[:3] = init_angles[j]
    ch.minimize(obj,
                x0=free_variables,
                method='dogleg',
                callback=on_step,
                options={
                    'maxiter': 100,
                    'e_3': .0001
                })

    print('Took %g' % (time() - t0))

    #import pdb; pdb.set_trace()

    if viz:
        dist = np.mean(model.r, axis=0)[2]
        img_here = render_mesh(Mesh(model.r, model.f), img.shape[1],
                               img.shape[0], cam)
        plt.imshow(img[:, :, ::-1])
        plt.imshow(img_here)

    return model.pose[:3].r, model.trans.r
            np.float32)).unsqueeze(0)
        betas_torch = torch.from_numpy(beta[:10].astype(
            np.float32)).unsqueeze(0)

        smpl_verts = smpl_torch.forward(pose_torch, betas_torch)
        transform = smpl_torch.A.detach().numpy()[0]
        transform_inv = np.array(
            [np.linalg.inv(transform[i]) for i in range(24)])

        #rotate root joint using inverse transform
        joints = smpl_torch.J_transformed.detach().numpy()[0]
        root_joint = np.array([joints[0, 0], joints[0, 1], joints[0, 2], 1])
        transformed_root = np.array(
            [np.matmul(transform_inv[i], root_joint)[:3] for i in range(24)])

        m1 = Mesh(v=smpl_verts.detach().numpy()[0], f=smpl_torch.faces)
        mesh = trimesh.Trimesh(m1.v, smpl_torch.faces)

        #sample points on mesh surface and displace with sigma = 0.03
        boundary_points = []
        points = mesh.sample(sample_num)
        boundary_points_1 = points + 0.03 * np.random.randn(sample_num, 3)

        #sample points in bbox of 110%
        bottom_corner, upper_corner = mesh.bounds
        bottom_corner = bottom_corner + 0.1 * bottom_corner
        upper_corner = upper_corner + 0.1 * upper_corner
        x_pt = np.random.uniform(bottom_corner[0], upper_corner[0], sample_num)
        y_pt = np.random.uniform(bottom_corner[1], upper_corner[1], sample_num)
        z_pt = np.random.uniform(bottom_corner[2], upper_corner[2], sample_num)
        boundary_points_2 = np.concatenate(
    args = parser.parse_args()

    if args.conf is None:
        args.conf = os.path.join(os.path.dirname(__file__), 'default.cfg')
        print('configuration file not specified, trying to load '
              'it from the script directory:', args.conf)

    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)

    config = read_config(args.conf)

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
Example #16
    def process(self):
        train_data, val_data, test_data = [], [], []
        train_vertices, val_vertices, test_vertices = [], [], []
        for idx, data_file in tqdm(enumerate(self.data_file)):
            mesh = Mesh(filename=data_file)
            mesh_verts = torch.Tensor(mesh.v)
            adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
            # edge_index = torch.LongTensor(np.vstack((adjacency.row, adjacency.col)))
            edge_index = torch.Tensor(np.vstack((adjacency.row, adjacency.col)))
            data = Data(x=mesh_verts, y=mesh_verts, edge_index=edge_index)

            if self.split == 'sliced':
                if idx % 100 <= 10:
                    test_data.append(data)
                    test_vertices.append(mesh.v)
                elif idx % 100 <= 20:
                    val_data.append(data)
                    val_vertices.append(mesh.v)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)

            elif self.split == 'expression':
                if data_file.split('/')[-2] == self.split_term:
                    test_data.append(data)
                    test_vertices.append(mesh.v)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)

            elif self.split == 'identity':
                if data_file.split('/')[-3] == self.split_term:
                    test_data.append(data)
                    test_vertices.append(mesh.v)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
            elif self.split == 'custom':
                test_terms = self.split_term.split('-')
                test_flag = False
                for term in test_terms:
                    if term in data_file:
                        test_flag = True
                if test_flag is True:
                    # print('testing set: {}'.format(data_file))
                    test_data.append(data)
                    test_vertices.append(mesh.v)
                else:
                    # print('training set: {}'.format(data_file))
                    train_data.append(data)
                    train_vertices.append(mesh.v)                    
            else:
                raise Exception('sliced, expression, identity and custom are the only supported split terms')

        if self.split != 'sliced':
            val_data = test_data[-self.nVal:]
            test_data = test_data[:-self.nVal]
            val_vertices = test_vertices[-self.nVal:]
            test_vertices = test_vertices[:-self.nVal]

        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
        if self.pre_transform is not None:
            print('Transforming data...')
            if hasattr(self.pre_transform, 'mean') and hasattr(self.pre_transform, 'std'):
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            val_data = [self.pre_transform(td) for td in val_data]
            test_data = [self.pre_transform(td) for td in test_data]

        print('Saving data...')
        torch.save(self.collate(train_data), self.processed_paths[0])
        torch.save(self.collate(val_data), self.processed_paths[1])
        torch.save(self.collate(test_data), self.processed_paths[2])
        torch.save(norm_dict, self.processed_paths[3])

        self.save_vertices(np.array(train_vertices), np.array(val_vertices), np.array(test_vertices))
Example #17
def fit_lmk2d(target_img, target_2d_lmks, model_fname, lmk_face_idx, lmk_b_coords, weights, visualize):
    '''
    Fit FLAME to 2D landmarks
    :param target_img:         target image the landmarks were detected in
    :param target_2d_lmks:     target 2D landmarks provided as (num_lmks x 2) matrix
    :param model_fname:        saved FLAME model
    :param lmk_face_idx:       face indices of the landmark embedding in the FLAME topology
    :param lmk_b_coords:       barycentric coordinates of the landmark embedding in the FLAME topology
                               (i.e. weighting of the three vertices of the triangle the landmark is embedded in)
    :param weights:            weights of the individual objective functions
    :param visualize:          visualize fitting progress
    :return: a mesh with the fitting results
    '''

    '''
    pred_types = {'face': pred_type(slice(0, 17), (0.682, 0.780, 0.909, 0.5)),
              'eyebrow1': pred_type(slice(17, 22), (1.0, 0.498, 0.055, 0.4)),
              'eyebrow2': pred_type(slice(22, 27), (1.0, 0.498, 0.055, 0.4)),
              'nose': pred_type(slice(27, 31), (0.345, 0.239, 0.443, 0.4)),
              'nostril': pred_type(slice(31, 36), (0.345, 0.239, 0.443, 0.4)),
              'eye1': pred_type(slice(36, 42), (0.596, 0.875, 0.541, 0.3)),
              'eye2': pred_type(slice(42, 48), (0.596, 0.875, 0.541, 0.3)),
              'lips': pred_type(slice(48, 60), (0.596, 0.875, 0.541, 0.3)),
              'teeth': pred_type(slice(60, 68), (0.596, 0.875, 0.541, 0.4))
              }
    '''

    lmks_weights = [[1,1]] * 68
    # Up-weight the eye landmarks (indices 36-47, cf. the pred_types listing above)
    for idx in range(36, 48):
        lmks_weights[idx] = [100, 100]

    tf_lmks_weights = tf.constant(
        lmks_weights,
        tf.float64
    )

    tf_trans = tf.Variable(np.zeros((1,3)), name="trans", dtype=tf.float64, trainable=True)
    tf_rot = tf.Variable(np.zeros((1,3)), name="rot", dtype=tf.float64, trainable=True)
    tf_pose = tf.Variable(np.zeros((1,12)), name="pose", dtype=tf.float64, trainable=True)
    tf_shape = tf.Variable(np.zeros((1,300)), name="shape", dtype=tf.float64, trainable=True)
    tf_exp = tf.Variable(np.zeros((1,100)), name="expression", dtype=tf.float64, trainable=True)
    smpl = SMPL(model_fname)
    tf_model = tf.squeeze(smpl(tf_trans,
                               tf.concat((tf_shape, tf_exp), axis=-1),
                               tf.concat((tf_rot, tf_pose), axis=-1)))

    with tf.Session() as session:
        # session.run(tf.global_variables_initializer())

        # Mirror landmark y-coordinates
        target_2d_lmks[:,1] = target_img.shape[0]-target_2d_lmks[:,1]

        lmks_3d = tf_get_model_lmks(tf_model, smpl.f, lmk_face_idx, lmk_b_coords)

        s2d = np.mean(np.linalg.norm(target_2d_lmks-np.mean(target_2d_lmks, axis=0), axis=1))
        s3d = tf.reduce_mean(tf.sqrt(tf.reduce_sum(tf.square(lmks_3d-tf.reduce_mean(lmks_3d, axis=0))[:, :2], axis=1)))
        tf_scale = tf.Variable(s2d/s3d, dtype=lmks_3d.dtype)

        # trans = 0.5*np.array((target_img.shape[0], target_img.shape[1]))/tf_scale
        # trans = 0.5 * s3d * np.array((target_img.shape[0], target_img.shape[1])) / s2d
        lmks_proj_2d = tf_project_points(lmks_3d, tf_scale, np.zeros(2))

        factor = max(max(target_2d_lmks[:,0]) - min(target_2d_lmks[:,0]),max(target_2d_lmks[:,1]) - min(target_2d_lmks[:,1]))
        #lmk_dist = weights['lmk']*tf.reduce_sum(tf.square(tf.subtract(lmks_proj_2d, target_2d_lmks))) / (factor ** 2)
        lmk_dist = weights['lmk']*tf.reduce_sum(
            tf.square(tf.subtract(lmks_proj_2d, target_2d_lmks)) * tf_lmks_weights
        ) / (factor ** 2)

        neck_pose_reg = weights['neck_pose']*tf.reduce_sum(tf.square(tf_pose[:,:3]))
        jaw_pose_reg = weights['jaw_pose']*tf.reduce_sum(tf.square(tf_pose[:,3:6]))
        eyeballs_pose_reg = weights['eyeballs_pose']*tf.reduce_sum(tf.square(tf_pose[:,6:]))
        shape_reg = weights['shape']*tf.reduce_sum(tf.square(tf_shape))
        exp_reg = weights['expr']*tf.reduce_sum(tf.square(tf_exp))

        session.run(tf.global_variables_initializer())

        if visualize:
            def on_step(verts, scale, faces, target_img, target_lmks, opt_lmks, lmk_dist=0.0, shape_reg=0.0, exp_reg=0.0, neck_pose_reg=0.0, jaw_pose_reg=0.0, eyeballs_pose_reg=0.0):
                import cv2
                import sys
                import numpy as np
                from psbody.mesh import Mesh
                from utils.render_mesh import render_mesh

                if lmk_dist>0.0 or shape_reg>0.0 or exp_reg>0.0 or neck_pose_reg>0.0 or jaw_pose_reg>0.0 or eyeballs_pose_reg>0.0:
                    print('lmk_dist: %f, shape_reg: %f, exp_reg: %f, neck_pose_reg: %f, jaw_pose_reg: %f, eyeballs_pose_reg: %f' % (lmk_dist, shape_reg, exp_reg, neck_pose_reg, jaw_pose_reg, eyeballs_pose_reg))

                plt_target_lmks = target_lmks.copy()
                plt_target_lmks[:, 1] = target_img.shape[0] - plt_target_lmks[:, 1]
                for (x, y) in plt_target_lmks:
                    cv2.circle(target_img, (int(x), int(y)), 4, (0, 0, 255), -1)

                plt_opt_lmks = opt_lmks.copy()
                plt_opt_lmks[:,1] = target_img.shape[0] - plt_opt_lmks[:,1]
                for (x, y) in plt_opt_lmks:
                    cv2.circle(target_img, (int(x), int(y)), 4, (255, 0, 0), -1)

                if sys.version_info >= (3, 0):
                    rendered_img = render_mesh(Mesh(scale*verts, faces), height=target_img.shape[0], width=target_img.shape[1])
                    for (x, y) in plt_opt_lmks:
                        cv2.circle(rendered_img, (int(x), int(y)), 4, (255, 0, 0), -1)
                    target_img = np.hstack((target_img, rendered_img))

                #cv2.imshow('img', target_img)
                #cv2.waitKey(10)
        else:
            def on_step(*_):
                pass

        print('Optimize rigid transformation')
        vars = [tf_scale, tf_trans, tf_rot]
        loss = lmk_dist
        optimizer = scipy_pt(loss=loss, var_list=vars, method='L-BFGS-B', options={'disp': 1, 'ftol': 5e-6})
        optimizer.minimize(session, fetches=[tf_model, tf_scale, tf.constant(smpl.f), tf.constant(target_img), tf.constant(target_2d_lmks), lmks_proj_2d], loss_callback=on_step)

        print('Optimize model parameters')
        vars = [tf_scale, tf_trans[:2], tf_rot, tf_pose, tf_shape, tf_exp]
        loss = lmk_dist + shape_reg + exp_reg + neck_pose_reg + jaw_pose_reg + eyeballs_pose_reg

        optimizer = scipy_pt(loss=loss, var_list=vars, method='L-BFGS-B', options={'disp': 0, 'ftol': 1e-7})
        optimizer.minimize(session, fetches=[tf_model, tf_scale, tf.constant(smpl.f), tf.constant(target_img), tf.constant(target_2d_lmks), lmks_proj_2d,
                                             lmk_dist, shape_reg, exp_reg, neck_pose_reg, jaw_pose_reg, eyeballs_pose_reg], loss_callback=on_step)

        print('Fitting done')
        np_verts, np_scale = session.run([tf_model, tf_scale])
        return Mesh(np_verts, smpl.f), np_scale
Example #18
def fit_lmk3d(target_3d_lmks,
              template_fname,
              tf_model_fname,
              lmk_face_idx,
              lmk_b_coords,
              weights,
              show_fitting=True):
    '''
    Fit FLAME to 3D landmarks
    :param target_3d_lmks:      target 3D landmarks provided as (num_lmks x 3) matrix
    :param template_fname:      template mesh in FLAME topology (only the face information is used)
    :param tf_model_fname:      saved Tensorflow FLAME model
    :param lmk_face_idx:        face indices of the landmark embedding in the FLAME topology
    :param lmk_b_coords:        barycentric coordinates of the landmark embedding in the FLAME topology
                                (i.e. weighting of the three vertices of the triangle the landmark is embedded in)
    :param weights:             weights of the individual objective functions
    :return: a mesh with the fitting results
    '''

    template_mesh = Mesh(filename=template_fname)
    saver = tf.train.import_meta_graph(tf_model_fname + '.meta')

    graph = tf.get_default_graph()
    tf_model = graph.get_tensor_by_name(u'vertices:0')

    with tf.Session() as session:
        saver.restore(session, tf_model_fname)

        # Workaround as existing tf.Variable cannot be retrieved back with tf.get_variable
        # tf_v_template = [x for x in tf.trainable_variables() if 'v_template' in x.name][0]
        tf_trans = [x for x in tf.trainable_variables()
                    if 'trans' in x.name][0]
        tf_rot = [x for x in tf.trainable_variables() if 'rot' in x.name][0]
        tf_pose = [x for x in tf.trainable_variables() if 'pose' in x.name][0]
        tf_shape = [x for x in tf.trainable_variables()
                    if 'shape' in x.name][0]
        tf_exp = [x for x in tf.trainable_variables() if 'exp' in x.name][0]

        lmks = tf_get_model_lmks(tf_model, template_mesh, lmk_face_idx,
                                 lmk_b_coords)
        lmk_dist = tf.reduce_sum(
            tf.square(1000 * tf.subtract(lmks, target_3d_lmks)))
        neck_pose_reg = tf.reduce_sum(tf.square(tf_pose[:3]))
        jaw_pose_reg = tf.reduce_sum(tf.square(tf_pose[3:6]))
        eyeballs_pose_reg = tf.reduce_sum(tf.square(tf_pose[6:]))
        shape_reg = tf.reduce_sum(tf.square(tf_shape))
        exp_reg = tf.reduce_sum(tf.square(tf_exp))

        # Optimize global transformation first
        vars = [tf_trans, tf_rot]
        loss = weights['lmk'] * lmk_dist
        optimizer = scipy_pt(loss=loss,
                             var_list=vars,
                             method='L-BFGS-B',
                             options={
                                 'disp': 1,
                                 'ftol': 5e-6
                             })
        print('Optimize rigid transformation')
        optimizer.minimize(session)

        # Optimize for the model parameters
        vars = [tf_trans, tf_rot, tf_pose, tf_shape, tf_exp]
        loss = weights['lmk'] * lmk_dist + weights['shape'] * shape_reg + weights['expr'] * exp_reg + \
               weights['neck_pose'] * neck_pose_reg + weights['jaw_pose'] * jaw_pose_reg + weights['eyeballs_pose'] * eyeballs_pose_reg

        optimizer = scipy_pt(loss=loss,
                             var_list=vars,
                             method='L-BFGS-B',
                             options={
                                 'disp': 1,
                                 'ftol': 5e-6
                             })
        print('Optimize model parameters')
        optimizer.minimize(session)

        print('Fitting done')

        if show_fitting:
            # Visualize landmark fitting
            mv = MeshViewer()
            mv.set_static_meshes(
                create_lmk_spheres(target_3d_lmks, 0.001, [255.0, 0.0, 0.0]))
            mv.set_dynamic_meshes(
                [Mesh(session.run(tf_model), template_mesh.f)] +
                create_lmk_spheres(session.run(lmks), 0.001,
                                   [0.0, 0.0, 255.0]),
                blocking=True)
            six.moves.input('Press key to continue')

        return Mesh(session.run(tf_model), template_mesh.f)
Example #20
    def process(self):
        train_data, val_data, test_data = [], [], []
        train_vertices = []
        for key in self.data_file:
            for idx, data_file in tqdm(enumerate(self.data_file[key])):
                if key == 'lgtd':
                    mesh = Mesh(filename=data_file[0])
                    mesh_verts = torch.Tensor(mesh.v)
                    adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                    edge_index = torch.Tensor(
                        np.vstack((adjacency.row, adjacency.col)))
                    mesh_m24 = Mesh(filename=data_file[1])
                    data = Data(x=mesh_verts,
                                y=torch.Tensor(mesh_m24.v),
                                edge_index=edge_index)
                elif key == 'lgtddx':
                    mesh = Mesh(filename=data_file[0])
                    mesh_verts = torch.Tensor(mesh.v)
                    adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                    edge_index = torch.Tensor(
                        np.vstack((adjacency.row, adjacency.col)))
                    mesh_m24 = Mesh(filename=data_file[1])
                    data = Data(x=mesh_verts,
                                y=torch.Tensor(mesh_m24.v),
                                edge_index=edge_index,
                                label=torch.Tensor(data_file[2]))
                elif key == 'lgtdvc':
                    #print(data_file)
                    mesh = Mesh(filename=data_file[0])
                    mesh_verts = torch.Tensor(mesh.v)
                    adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                    edge_index = torch.Tensor(
                        np.vstack((adjacency.row, adjacency.col)))
                    mesh_fu = Mesh(filename=data_file[1])
                    data = Data(x=mesh_verts,
                                y=torch.Tensor(mesh_fu.v),
                                edge_index=edge_index,
                                label=torch.Tensor(data_file[2]),
                                period=torch.Tensor(data_file[3]))
                else:
                    mesh = Mesh(filename=data_file)
                    mesh_verts = torch.Tensor(mesh.v)
                    adjacency = get_vert_connectivity(mesh.v, mesh.f).tocoo()
                    edge_index = torch.Tensor(
                        np.vstack((adjacency.row, adjacency.col)))
                    if key == 'ad':
                        data = Data(x=mesh_verts,
                                    y=mesh_verts,
                                    label=torch.Tensor([1, 0]),
                                    edge_index=edge_index)
                    elif key == 'cn':
                        data = Data(x=mesh_verts,
                                    y=mesh_verts,
                                    label=torch.Tensor([0, 1]),
                                    edge_index=edge_index)
                    elif key == 'all':
                        data = Data(x=mesh_verts,
                                    y=mesh_verts,
                                    edge_index=edge_index)

                if idx % 100 < 10:
                    test_data.append(data)
                    #print(data.period)
                elif idx % 100 < 20:
                    val_data.append(data)
                else:
                    train_data.append(data)
                    train_vertices.append(mesh.v)
        print(len(train_data), len(val_data), len(test_data))

        mean_train = torch.Tensor(np.mean(train_vertices, axis=0))
        std_train = torch.Tensor(np.std(train_vertices, axis=0))
        norm_dict = {'mean': mean_train, 'std': std_train}
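        # If a normalization pre_transform was supplied without statistics, fill in the
        # training mean/std computed above before applying it to every split.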
        if self.pre_transform is not None:
            if hasattr(self.pre_transform, 'mean') and hasattr(
                    self.pre_transform, 'std'):
                if self.pre_transform.mean is None:
                    self.pre_transform.mean = mean_train
                if self.pre_transform.std is None:
                    self.pre_transform.std = std_train
            train_data = [self.pre_transform(td) for td in train_data]
            val_data = [self.pre_transform(td) for td in val_data]
            test_data = [self.pre_transform(td) for td in test_data]

        torch.save(self.collate(train_data), self.processed_paths[0])
        torch.save(self.collate(val_data), self.processed_paths[1])
        torch.save(self.collate(test_data), self.processed_paths[2])
        torch.save(norm_dict, self.processed_paths[3])
Example #21
def vis_results(dorig,
                coarse_net,
                refine_net,
                rh_model,
                show_gen=True,
                show_rec=True,
                save=False,
                save_dir=None):

    with torch.no_grad():
        imw, imh = 400, 1000
        cols = len(dorig['bps_object'])
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        if show_rec:
            mvs = MeshViewers(window_width=imw * cols,
                              window_height=imh,
                              shape=[3, cols],
                              keepalive=True)
            drec_cnet = coarse_net(**dorig)
            verts_rh_rec_cnet = rh_model(**drec_cnet).vertices

            _, h2o, _ = point2point_signed(verts_rh_rec_cnet,
                                           dorig['verts_object'])

            drec_cnet['trans_rhand_f'] = drec_cnet['transl']
            drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
                drec_cnet['global_orient']).view(-1, 3, 3)
            drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(
                drec_cnet['hand_pose']).view(-1, 15, 3, 3)
            drec_cnet['verts_object'] = dorig['verts_object']
            drec_cnet['h2o_dist'] = h2o.abs()

            drec_rnet = refine_net(**drec_cnet)
            verts_rh_rec_rnet = rh_model(**drec_rnet).vertices

            for cId in range(0, len(dorig['bps_object'])):
                try:
                    from copy import deepcopy
                    meshes = deepcopy(dorig['mesh_object'])
                    obj_mesh = [meshes[cId]]
                except:
                    obj_mesh = points_to_spheres(points=to_cpu(
                        dorig['verts_object'][cId]),
                                                 radius=0.002,
                                                 vc=name_to_rgb['green'])

                hand_mesh_orig = Mesh(v=to_cpu(dorig['verts_rhand'][cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['blue'])
                hand_mesh_rec_cnet = Mesh(v=to_cpu(verts_rh_rec_cnet[cId]),
                                          f=rh_model.faces,
                                          vc=name_to_rgb['green'])
                hand_mesh_rec_rnet = Mesh(v=to_cpu(verts_rh_rec_rnet[cId]),
                                          f=rh_model.faces,
                                          vc=name_to_rgb['red'])

                if 'rotmat' in dorig:
                    rotmat = dorig['rotmat'][cId].T
                    obj_mesh = [obj_mesh[0].rotate_vertices(rotmat)]
                    hand_mesh_orig.rotate_vertices(rotmat)
                    hand_mesh_rec_cnet.rotate_vertices(rotmat)
                    hand_mesh_rec_rnet.rotate_vertices(rotmat)

                mvs[0][cId].set_static_meshes([hand_mesh_orig] + obj_mesh,
                                              blocking=True)
                mvs[1][cId].set_static_meshes([hand_mesh_rec_cnet] + obj_mesh,
                                              blocking=True)
                mvs[2][cId].set_static_meshes([hand_mesh_rec_rnet] + obj_mesh,
                                              blocking=True)

                if save:
                    save_path = os.path.join(save_dir, str(cId))
                    makepath(save_path)
                    hand_mesh_rec_rnet.write_ply(filename=save_path +
                                                 '/rh_mesh_gen_%d.ply' % cId)
                    obj_mesh[0].write_ply(filename=save_path +
                                          '/obj_mesh_%d.ply' % cId)

        if show_gen:
            mvs = MeshViewers(window_width=imw * cols,
                              window_height=imh,
                              shape=[2, cols],
                              keepalive=True)

            drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
            verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

            _, h2o, _ = point2point_signed(verts_rh_gen_cnet,
                                           dorig['verts_object'].to(device))

            drec_cnet['trans_rhand_f'] = drec_cnet['transl']
            drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
                drec_cnet['global_orient']).view(-1, 3, 3)
            drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(
                drec_cnet['hand_pose']).view(-1, 15, 3, 3)
            drec_cnet['verts_object'] = dorig['verts_object'].to(device)
            drec_cnet['h2o_dist'] = h2o.abs()

            drec_rnet = refine_net(**drec_cnet)
            verts_rh_gen_rnet = rh_model(**drec_rnet).vertices

            for cId in range(0, len(dorig['bps_object'])):
                try:
                    from copy import deepcopy
                    meshes = deepcopy(dorig['mesh_object'])
                    obj_mesh = [meshes[cId]]
                except:
                    obj_mesh = points_to_spheres(to_cpu(
                        dorig['verts_object'][cId]),
                                                 radius=0.002,
                                                 vc=name_to_rgb['green'])

                hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]),
                                          f=rh_model.faces,
                                          vc=name_to_rgb['pink'])
                hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]),
                                          f=rh_model.faces,
                                          vc=name_to_rgb['gray'])

                if 'rotmat' in dorig:
                    rotmat = dorig['rotmat'][cId].T
                    obj_mesh = [obj_mesh[0].rotate_vertices(rotmat)]
                    hand_mesh_gen_cnet.rotate_vertices(rotmat)
                    hand_mesh_gen_rnet.rotate_vertices(rotmat)

                mvs[0][cId].set_static_meshes([hand_mesh_gen_cnet] + obj_mesh,
                                              blocking=True)
                mvs[1][cId].set_static_meshes([hand_mesh_gen_rnet] + obj_mesh,
                                              blocking=True)

                if save:
                    save_path = os.path.join(save_dir, str(cId))
                    makepath(save_path)
                    hand_mesh_gen_rnet.write_ply(filename=save_path +
                                                 '/rh_mesh_gen_%d.ply' % cId)
                    obj_mesh[0].write_ply(filename=save_path +
                                          '/obj_mesh_%d.ply' % cId)
Example #22
# args = parser.parse_args()
def frontalize(vertices, canonical_vertices):
    # mesh = Mesh(filename="/home/user/3dfaceRe/center-loss_conv/scripts/template_fwh.obj")
    # canonical_vertices = mesh.v
    #canonical_vertices = np.load('Data/uv-data/canonical_vertices.npy')
    vertices_homo = np.hstack((vertices, np.ones([vertices.shape[0],
                                                  1])))  #n x 4
    P = np.linalg.lstsq(vertices_homo,
                        canonical_vertices)[0].T  # Affine matrix. 3 x 4
    front_vertices = vertices_homo.dot(P.T)

    return front_vertices
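
# A minimal, hedged sanity check (not part of the original script; it assumes numpy is
# already imported as np, as elsewhere in this file). frontalize solves a least-squares
# affine alignment, so points produced by a known invertible affine map should be sent
# back onto the canonical points.
def _check_frontalize():
    rng = np.random.RandomState(0)
    canonical = rng.randn(100, 3)
    affine = rng.randn(3, 4)  # random affine map; its 3x3 linear part is (almost surely) invertible
    moved = np.hstack((canonical, np.ones((100, 1)))).dot(affine.T)
    recovered = frontalize(moved, canonical)
    assert np.allclose(recovered, canonical, atol=1e-5)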


template = Mesh(filename='./data/template.obj')

nVal = 100  # args.num_valid
root_dir = './'  # args.root_dir
dataset = 'data'  # args.dataset
name = 'sliced'

data = os.path.join(root_dir, dataset, 'Processed', name)

train = np.load(data + '/train.npy')
# train = torch.tensor(train.astype('float32'))
# mean = torch.mean(train, dim=0)
# std = torch.std(train, dim=0)
# torch.save(mean, './data/Processed/sliced/mean.tch')
# torch.save(std, './data/Processed/sliced/std.tch')
Example #23
def fit_lmk2d(target_img, target_2d_lmks, template_fname, tf_model_fname,
              lmk_face_idx, lmk_b_coords, weights):
    '''
    Fit FLAME to 2D landmarks
    :param target_2d_lmks:      target 2D landmarks provided as (num_lmks x 2) matrix
    :param template_fname:      template mesh in FLAME topology (only the face information is used)
    :param tf_model_fname:      saved Tensorflow FLAME model
    :param lmk_face_idx:        face indices of the landmark embedding in the FLAME topology
    :param lmk_b_coords:        barycentric coordinates of the landmark embedding in the FLAME topology
                                (i.e. weighting of the three vertices of the triangle the landmark is embedded in)
    :param weights:             weights of the individual objective functions
    :return: a mesh with the fitting results
    '''

    template_mesh = Mesh(filename=template_fname)
    saver = tf.train.import_meta_graph(tf_model_fname + '.meta')

    graph = tf.get_default_graph()
    tf_model = graph.get_tensor_by_name(u'vertices:0')

    with tf.Session() as session:
        saver.restore(session, tf_model_fname)

        # Workaround as existing tf.Variable cannot be retrieved back with tf.get_variable
        # tf_v_template = [x for x in tf.trainable_variables() if 'v_template' in x.name][0]
        tf_trans = [x for x in tf.trainable_variables()
                    if 'trans' in x.name][0]
        tf_rot = [x for x in tf.trainable_variables() if 'rot' in x.name][0]
        tf_pose = [x for x in tf.trainable_variables() if 'pose' in x.name][0]
        tf_shape = [x for x in tf.trainable_variables()
                    if 'shape' in x.name][0]
        tf_exp = [x for x in tf.trainable_variables() if 'exp' in x.name][0]

        # Mirror landmark y-coordinates
        target_2d_lmks[:, 1] = target_img.shape[0] - target_2d_lmks[:, 1]

        lmks_3d = tf_get_model_lmks(tf_model, template_mesh, lmk_face_idx,
                                    lmk_b_coords)

        s2d = np.mean(
            np.linalg.norm(target_2d_lmks - np.mean(target_2d_lmks, axis=0),
                           axis=1))
        s3d = tf.reduce_mean(
            tf.sqrt(
                tf.reduce_sum(
                    tf.square(lmks_3d -
                              tf.reduce_mean(lmks_3d, axis=0))[:, :2],
                    axis=1)))
        tf_scale = tf.Variable(s2d / s3d, dtype=lmks_3d.dtype)
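        # tf_scale is initialized to the ratio of the mean radial spread of the target
        # 2D landmarks to the mean (x, y) spread of the model landmarks.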

        # trans = 0.5*np.array((target_img.shape[0], target_img.shape[1]))/tf_scale
        # trans = 0.5 * s3d * np.array((target_img.shape[0], target_img.shape[1])) / s2d
        lmks_proj_2d = tf_project_points(lmks_3d, tf_scale, np.zeros(2))

        factor = max(
            max(target_2d_lmks[:, 0]) - min(target_2d_lmks[:, 0]),
            max(target_2d_lmks[:, 1]) - min(target_2d_lmks[:, 1]))
        lmk_dist = weights['lmk'] * tf.reduce_sum(
            tf.square(tf.subtract(lmks_proj_2d, target_2d_lmks))) / (factor**2)
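        # The landmark term is normalized by the squared extent of the target landmark
        # bounding box, making it independent of the image resolution.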
        neck_pose_reg = weights['neck_pose'] * tf.reduce_sum(
            tf.square(tf_pose[:3]))
        jaw_pose_reg = weights['jaw_pose'] * tf.reduce_sum(
            tf.square(tf_pose[3:6]))
        eyeballs_pose_reg = weights['eyeballs_pose'] * tf.reduce_sum(
            tf.square(tf_pose[6:]))
        shape_reg = weights['shape'] * tf.reduce_sum(tf.square(tf_shape))
        exp_reg = weights['expr'] * tf.reduce_sum(tf.square(tf_exp))

        session.run(tf.global_variables_initializer())

        def on_step(verts,
                    scale,
                    faces,
                    target_img,
                    target_lmks,
                    opt_lmks,
                    lmk_dist=0.0,
                    shape_reg=0.0,
                    exp_reg=0.0,
                    neck_pose_reg=0.0,
                    jaw_pose_reg=0.0,
                    eyeballs_pose_reg=0.0):
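            # Visualization callback for the SciPy optimizer interface: draws the target
            # 2D landmarks (red) and the current model landmarks (blue) on the image and,
            # on Python 3, also shows the rendered fitting mesh side by side.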
            import cv2
            import sys
            import numpy as np
            from psbody.mesh import Mesh
            from utils.render_mesh import render_mesh

            if lmk_dist > 0.0 or shape_reg > 0.0 or exp_reg > 0.0 or neck_pose_reg > 0.0 or jaw_pose_reg > 0.0 or eyeballs_pose_reg > 0.0:
                print(
                    'lmk_dist: %f, shape_reg: %f, exp_reg: %f, neck_pose_reg: %f, jaw_pose_reg: %f, eyeballs_pose_reg: %f'
                    % (lmk_dist, shape_reg, exp_reg, neck_pose_reg,
                       jaw_pose_reg, eyeballs_pose_reg))

            plt_target_lmks = target_lmks.copy()
            plt_target_lmks[:, 1] = target_img.shape[0] - plt_target_lmks[:, 1]
            for (x, y) in plt_target_lmks:
                cv2.circle(target_img, (int(x), int(y)), 4, (0, 0, 255), -1)

            plt_opt_lmks = opt_lmks.copy()
            plt_opt_lmks[:, 1] = target_img.shape[0] - plt_opt_lmks[:, 1]
            for (x, y) in plt_opt_lmks:
                cv2.circle(target_img, (int(x), int(y)), 4, (255, 0, 0), -1)

            if sys.version_info >= (3, 0):
                rendered_img = render_mesh(Mesh(scale * verts, faces),
                                           height=target_img.shape[0],
                                           width=target_img.shape[1])
                for (x, y) in plt_opt_lmks:
                    cv2.circle(rendered_img, (int(x), int(y)), 4, (255, 0, 0),
                               -1)
                target_img = np.hstack((target_img, rendered_img))

            cv2.imshow('img', target_img)
            cv2.waitKey(10)

        print('Optimize rigid transformation')
        vars = [tf_scale, tf_trans, tf_rot]
        loss = lmk_dist
        optimizer = scipy_pt(loss=loss,
                             var_list=vars,
                             method='L-BFGS-B',
                             options={
                                 'disp': 1,
                                 'ftol': 5e-6
                             })
        optimizer.minimize(session,
                           fetches=[
                               tf_model, tf_scale,
                               tf.constant(template_mesh.f),
                               tf.constant(target_img),
                               tf.constant(target_2d_lmks), lmks_proj_2d
                           ],
                           loss_callback=on_step)

        print('Optimize model parameters')
        vars = [tf_scale, tf_trans[:2], tf_rot, tf_pose, tf_shape, tf_exp]
        loss = lmk_dist + shape_reg + exp_reg + neck_pose_reg + jaw_pose_reg + eyeballs_pose_reg

        optimizer = scipy_pt(loss=loss,
                             var_list=vars,
                             method='L-BFGS-B',
                             options={
                                 'disp': 0,
                                 'ftol': 1e-7
                             })
        optimizer.minimize(session,
                           fetches=[
                               tf_model, tf_scale,
                               tf.constant(template_mesh.f),
                               tf.constant(target_img),
                               tf.constant(target_2d_lmks), lmks_proj_2d,
                               lmk_dist, shape_reg, exp_reg, neck_pose_reg,
                               jaw_pose_reg, eyeballs_pose_reg
                           ],
                           loss_callback=on_step)

        print('Fitting done')
        np_verts, np_scale = session.run([tf_model, tf_scale])
        return Mesh(np_verts, template_mesh.f), np_scale
Example #24
            (pack(results[i].dot(ch.concatenate((self.J[i, :], [0])))))
            for i in range(len(results))
        ]
        result = ch.dstack(results2)

        return result, results_global

    def compute_r(self):
        return self.v.r

    def compute_dr_wrt(self, wrt):
        if wrt is not self.trans and wrt is not self.betas and wrt is not self.pose and wrt is not self.v_personal and wrt is not self.v_template:
            return None

        return self.v.dr_wrt(wrt)


if __name__ == '__main__':
    from utils.smpl_paths import SmplPaths

    dp = SmplPaths(gender='neutral')

    smpl = Smpl(dp.get_smpl_file())

    from psbody.mesh.meshviewer import MeshViewer
    from psbody.mesh import Mesh
    mv = MeshViewer()
    mv.set_static_meshes([Mesh(smpl.r, smpl.f)])

    raw_input("Press Enter to continue...")
    ## This file contains correspondences between garment vertices and the SMPL body
    fts_file = 'assets/garment_fts.pkl'
    vert_indices, fts = pkl.load(open(fts_file, "rb"), encoding="latin1")
    fts['naked'] = ft

    ## Choose any garment type as the source
    garment_type = 'TShirtNoCoat'
    index = np.random.randint(0, len(gar_dict[garment_type]))   ## Randomly pick from the digital wardrobe
    path = split(gar_dict[garment_type][index])[0]


    garment_org_body_unposed = load_smpl_from_file(join(path, 'registration.pkl'))
    garment_org_body_unposed.pose[:] = 0
    garment_org_body_unposed.trans[:] = 0
    garment_org_body_unposed = Mesh(garment_org_body_unposed.v, garment_org_body_unposed.f)

    garment_unposed = Mesh(filename=join(path, garment_type + '.obj'))
    garment_tex = join(path, 'multi_tex.jpg')

    ## Generate a random SMPL body (feel free to set up your own SMPL) as the target subject
    smpl.pose[:] = np.random.randn(72) *0.05
    smpl.betas[:] = np.random.randn(10) *0.01
    smpl.trans[:] = 0
    tgt_body = Mesh(smpl.r, smpl.f)

    vert_inds = vert_indices[garment_type]
    garment_unposed.set_texture_image(garment_tex)

    new_garment = dress(smpl, garment_org_body_unposed, garment_unposed, vert_inds, garment_tex)
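    ## Assumed behaviour (not shown in this snippet): `dress` re-targets the unposed
    ## garment from the source registration onto the newly sampled target body.
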
def fit_single_frame(
        img,
        keypoints,
        init_trans,
        scan,
        scene_name,
        body_model,
        camera,
        joint_weights,
        body_pose_prior,
        jaw_prior,
        left_hand_prior,
        right_hand_prior,
        shape_prior,
        expr_prior,
        angle_prior,
        result_fn='out.pkl',
        mesh_fn='out.obj',
        body_scene_rendering_fn='body_scene.png',
        out_img_fn='overlay.png',
        loss_type='smplify',
        use_cuda=True,
        init_joints_idxs=(9, 12, 2, 5),
        use_face=True,
        use_hands=True,
        data_weights=None,
        body_pose_prior_weights=None,
        hand_pose_prior_weights=None,
        jaw_pose_prior_weights=None,
        shape_weights=None,
        expr_weights=None,
        hand_joints_weights=None,
        face_joints_weights=None,
        depth_loss_weight=1e2,
        interpenetration=True,
        coll_loss_weights=None,
        df_cone_height=0.5,
        penalize_outside=True,
        max_collisions=8,
        point2plane=False,
        part_segm_fn='',
        focal_length_x=5000.,
        focal_length_y=5000.,
        side_view_thsh=25.,
        rho=100,
        vposer_latent_dim=32,
        vposer_ckpt='',
        use_joints_conf=False,
        interactive=True,
        visualize=False,
        save_meshes=True,
        degrees=None,
        batch_size=1,
        dtype=torch.float32,
        ign_part_pairs=None,
        left_shoulder_idx=2,
        right_shoulder_idx=5,
        ####################
        ### PROX
        render_results=True,
        camera_mode='moving',
        ## Depth
        s2m=False,
        s2m_weights=None,
        m2s=False,
        m2s_weights=None,
        rho_s2m=1,
        rho_m2s=1,
        init_mode=None,
        trans_opt_stages=None,
        viz_mode='mv',
        #penetration
        sdf_penetration=False,
        sdf_penetration_weights=0.0,
        sdf_dir=None,
        cam2world_dir=None,
        #contact
        contact=False,
        rho_contact=1.0,
        contact_loss_weights=None,
        contact_angle=15,
        contact_body_parts=None,
        body_segments_dir=None,
        load_scene=False,
        scene_dir=None,
        **kwargs):
    assert batch_size == 1, 'PyTorch L-BFGS only supports batch_size == 1'
    body_model.reset_params()
    body_model.transl.requires_grad = True

    device = torch.device('cuda') if use_cuda else torch.device('cpu')

    if visualize:
        pil_img.fromarray((img * 255).astype(np.uint8)).show()

    if degrees is None:
        degrees = [0, 90, 180, 270]

    if data_weights is None:
        data_weights = [
            1,
        ] * 5

    if body_pose_prior_weights is None:
        body_pose_prior_weights = [4.04 * 1e2, 4.04 * 1e2, 57.4, 4.78]

    msg = ('Number of Body pose prior weights {}'.format(
        len(body_pose_prior_weights)) +
           ' does not match the number of data term weights {}'.format(
               len(data_weights)))
    assert (len(data_weights) == len(body_pose_prior_weights)), msg

    if use_hands:
        if hand_pose_prior_weights is None:
            hand_pose_prior_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of hand pose prior weights')
        assert (
            len(hand_pose_prior_weights) == len(body_pose_prior_weights)), msg
        if hand_joints_weights is None:
            hand_joints_weights = [0.0, 0.0, 0.0, 1.0]
            msg = ('Number of Body pose prior weights does not match the' +
                   ' number of hand joint distance weights')
            assert (
                len(hand_joints_weights) == len(body_pose_prior_weights)), msg

    if shape_weights is None:
        shape_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
    msg = ('Number of Body pose prior weights = {} does not match the' +
           ' number of Shape prior weights = {}')
    assert (len(shape_weights) == len(body_pose_prior_weights)), msg.format(
        len(body_pose_prior_weights), len(shape_weights))

    if use_face:
        if jaw_pose_prior_weights is None:
            jaw_pose_prior_weights = [[x] * 3 for x in shape_weights]
        else:
            jaw_pose_prior_weights = map(lambda x: map(float, x.split(',')),
                                         jaw_pose_prior_weights)
            jaw_pose_prior_weights = [list(w) for w in jaw_pose_prior_weights]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of jaw pose prior weights')
        assert (
            len(jaw_pose_prior_weights) == len(body_pose_prior_weights)), msg

        if expr_weights is None:
            expr_weights = [1e2, 5 * 1e1, 1e1, .5 * 1e1]
        msg = ('Number of Body pose prior weights = {} does not match the' +
               ' number of Expression prior weights = {}')
        assert (len(expr_weights) == len(body_pose_prior_weights)), msg.format(
            len(body_pose_prior_weights), len(expr_weights))

        if face_joints_weights is None:
            face_joints_weights = [0.0, 0.0, 0.0, 1.0]
        msg = ('Number of Body pose prior weights does not match the' +
               ' number of face joint distance weights')
        assert (len(face_joints_weights) == len(body_pose_prior_weights)), msg

    if coll_loss_weights is None:
        coll_loss_weights = [0.0] * len(body_pose_prior_weights)
    msg = ('Number of Body pose prior weights does not match the' +
           ' number of collision loss weights')
    assert (len(coll_loss_weights) == len(body_pose_prior_weights)), msg

    use_vposer = kwargs.get('use_vposer', True)
    vposer, pose_embedding = [
        None,
    ] * 2
    if use_vposer:
        pose_embedding = torch.zeros([batch_size, 32],
                                     dtype=dtype,
                                     device=device,
                                     requires_grad=True)

        vposer_ckpt = osp.expandvars(vposer_ckpt)
        vposer, _ = load_vposer(vposer_ckpt, vp_model='snapshot')
        vposer = vposer.to(device=device)
        vposer.eval()

    if use_vposer:
        body_mean_pose = torch.zeros([batch_size, vposer_latent_dim],
                                     dtype=dtype)
    else:
        body_mean_pose = body_pose_prior.get_mean().detach().cpu()

    keypoint_data = torch.tensor(keypoints, dtype=dtype)
    gt_joints = keypoint_data[:, :, :2]
    if use_joints_conf:
        joints_conf = keypoint_data[:, :, 2].reshape(1, -1)

    # Transfer the data to the correct device
    gt_joints = gt_joints.to(device=device, dtype=dtype)
    if use_joints_conf:
        joints_conf = joints_conf.to(device=device, dtype=dtype)

    scan_tensor = None
    if scan is not None:
        scan_tensor = torch.tensor(scan.get('points'),
                                   device=device,
                                   dtype=dtype).unsqueeze(0)

    # load pre-computed signed distance field
    sdf = None
    sdf_normals = None
    grid_min = None
    grid_max = None
    voxel_size = None
    if sdf_penetration:
        with open(osp.join(sdf_dir, scene_name + '.json'), 'r') as f:
            sdf_data = json.load(f)
            grid_min = torch.tensor(np.array(sdf_data['min']),
                                    dtype=dtype,
                                    device=device)
            grid_max = torch.tensor(np.array(sdf_data['max']),
                                    dtype=dtype,
                                    device=device)
            grid_dim = sdf_data['dim']
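        # Edge length of one SDF voxel along each axis (the grid has grid_dim cells per side).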
        voxel_size = (grid_max - grid_min) / grid_dim
        sdf = np.load(osp.join(sdf_dir, scene_name + '_sdf.npy')).reshape(
            grid_dim, grid_dim, grid_dim)
        sdf = torch.tensor(sdf, dtype=dtype, device=device)
        if osp.exists(osp.join(sdf_dir, scene_name + '_normals.npy')):
            sdf_normals = np.load(
                osp.join(sdf_dir, scene_name + '_normals.npy')).reshape(
                    grid_dim, grid_dim, grid_dim, 3)
            sdf_normals = torch.tensor(sdf_normals, dtype=dtype, device=device)
        else:
            print("Normals not found...")

    with open(os.path.join(cam2world_dir, scene_name + '.json'), 'r') as f:
        cam2world = np.array(json.load(f))
        R = torch.tensor(cam2world[:3, :3].reshape(3, 3),
                         dtype=dtype,
                         device=device)
        t = torch.tensor(cam2world[:3, 3].reshape(1, 3),
                         dtype=dtype,
                         device=device)

    # Create the search tree
    search_tree = None
    pen_distance = None
    filter_faces = None
    if interpenetration:
        from mesh_intersection.bvh_search_tree import BVH
        import mesh_intersection.loss as collisions_loss
        from mesh_intersection.filter_faces import FilterFaces

        assert use_cuda, 'Interpenetration term can only be used with CUDA'
        assert torch.cuda.is_available(), \
            'No CUDA Device! Interpenetration term can only be used' + \
            ' with CUDA'

        search_tree = BVH(max_collisions=max_collisions)

        pen_distance = \
            collisions_loss.DistanceFieldPenetrationLoss(
                sigma=df_cone_height, point2plane=point2plane,
                vectorized=True, penalize_outside=penalize_outside)

        if part_segm_fn:
            # Read the part segmentation
            part_segm_fn = os.path.expandvars(part_segm_fn)
            with open(part_segm_fn, 'rb') as faces_parents_file:
                face_segm_data = pickle.load(faces_parents_file,
                                             encoding='latin1')
            faces_segm = face_segm_data['segm']
            faces_parents = face_segm_data['parents']
            # Create the module used to filter invalid collision pairs
            filter_faces = FilterFaces(
                faces_segm=faces_segm,
                faces_parents=faces_parents,
                ign_part_pairs=ign_part_pairs).to(device=device)

    # load vertex ids of contact parts
    contact_verts_ids = ftov = None
    if contact:
        contact_verts_ids = []
        for part in contact_body_parts:
            with open(os.path.join(body_segments_dir, part + '.json'),
                      'r') as f:
                data = json.load(f)
                contact_verts_ids.append(list(set(data["verts_ind"])))
        contact_verts_ids = np.concatenate(contact_verts_ids)

        vertices = body_model(return_verts=True,
                              body_pose=torch.zeros((batch_size, 63),
                                                    dtype=dtype,
                                                    device=device)).vertices
        vertices_np = vertices.detach().cpu().numpy().squeeze()
        body_faces_np = body_model.faces_tensor.detach().cpu().numpy().reshape(
            -1, 3)
        m = Mesh(v=vertices_np, f=body_faces_np)
        ftov = m.faces_by_vertex(as_sparse_matrix=True)
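        # faces_by_vertex gives a sparse vertex-to-face incidence matrix; below it is
        # converted to a torch sparse tensor so the contact term can aggregate per-face
        # quantities (e.g. normals) onto vertices.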

        ftov = sparse.coo_matrix(ftov)
        indices = torch.LongTensor(np.vstack((ftov.row, ftov.col))).to(device)
        values = torch.FloatTensor(ftov.data).to(device)
        shape = ftov.shape
        ftov = torch.sparse.FloatTensor(indices, values, torch.Size(shape))

    # Read the scene scan if any
    scene_v = scene_vn = scene_f = None
    if scene_name is not None:
        if load_scene:
            scene = Mesh(filename=os.path.join(scene_dir, scene_name + '.ply'))

            scene.vn = scene.estimate_vertex_normals()

            scene_v = torch.tensor(scene.v[np.newaxis, :],
                                   dtype=dtype,
                                   device=device).contiguous()
            scene_vn = torch.tensor(scene.vn[np.newaxis, :],
                                    dtype=dtype,
                                    device=device)
            scene_f = torch.tensor(scene.f.astype(int)[np.newaxis, :],
                                   dtype=torch.long,
                                   device=device)

    # Weights used for the pose prior and the shape prior
    opt_weights_dict = {
        'data_weight': data_weights,
        'body_pose_weight': body_pose_prior_weights,
        'shape_weight': shape_weights
    }
    if use_face:
        opt_weights_dict['face_weight'] = face_joints_weights
        opt_weights_dict['expr_prior_weight'] = expr_weights
        opt_weights_dict['jaw_prior_weight'] = jaw_pose_prior_weights
    if use_hands:
        opt_weights_dict['hand_weight'] = hand_joints_weights
        opt_weights_dict['hand_prior_weight'] = hand_pose_prior_weights
    if interpenetration:
        opt_weights_dict['coll_loss_weight'] = coll_loss_weights
    if s2m:
        opt_weights_dict['s2m_weight'] = s2m_weights
    if m2s:
        opt_weights_dict['m2s_weight'] = m2s_weights
    if sdf_penetration:
        opt_weights_dict['sdf_penetration_weight'] = sdf_penetration_weights
    if contact:
        opt_weights_dict['contact_loss_weight'] = contact_loss_weights

    keys = opt_weights_dict.keys()
    opt_weights = [
        dict(zip(keys, vals))
        for vals in zip(*(opt_weights_dict[k] for k in keys
                          if opt_weights_dict[k] is not None))
    ]
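    # opt_weights now holds one weight dictionary per optimization stage; convert each
    # value to a tensor on the target device.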
    for weight_list in opt_weights:
        for key in weight_list:
            weight_list[key] = torch.tensor(weight_list[key],
                                            device=device,
                                            dtype=dtype)

    # load indices of the head of smpl-x model
    with open(osp.join(body_segments_dir, 'body_mask.json'), 'r') as fp:
        head_indx = np.array(json.load(fp))
    N = body_model.get_num_verts()
    body_indx = np.setdiff1d(np.arange(N), head_indx)
    head_mask = np.in1d(np.arange(N), head_indx)
    body_mask = np.in1d(np.arange(N), body_indx)

    # The indices of the joints used for the initialization of the camera
    init_joints_idxs = torch.tensor(init_joints_idxs, device=device)

    edge_indices = kwargs.get('body_tri_idxs')

    # which initialization mode to choose: similar triangles, mean of the scan, or the average of both
    if init_mode == 'scan':
        init_t = init_trans
    elif init_mode == 'both':
        init_t = (init_trans.to(device) + fitting.guess_init(
            body_model,
            gt_joints,
            edge_indices,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            model_type=kwargs.get('model_type', 'smpl'),
            focal_length=focal_length_x,
            dtype=dtype)) / 2.0

    else:
        init_t = fitting.guess_init(body_model,
                                    gt_joints,
                                    edge_indices,
                                    use_vposer=use_vposer,
                                    vposer=vposer,
                                    pose_embedding=pose_embedding,
                                    model_type=kwargs.get(
                                        'model_type', 'smpl'),
                                    focal_length=focal_length_x,
                                    dtype=dtype)

    camera_loss = fitting.create_loss('camera_init',
                                      trans_estimation=init_t,
                                      init_joints_idxs=init_joints_idxs,
                                      depth_loss_weight=depth_loss_weight,
                                      camera_mode=camera_mode,
                                      dtype=dtype).to(device=device)
    camera_loss.trans_estimation[:] = init_t

    loss = fitting.create_loss(loss_type=loss_type,
                               joint_weights=joint_weights,
                               rho=rho,
                               use_joints_conf=use_joints_conf,
                               use_face=use_face,
                               use_hands=use_hands,
                               vposer=vposer,
                               pose_embedding=pose_embedding,
                               body_pose_prior=body_pose_prior,
                               shape_prior=shape_prior,
                               angle_prior=angle_prior,
                               expr_prior=expr_prior,
                               left_hand_prior=left_hand_prior,
                               right_hand_prior=right_hand_prior,
                               jaw_prior=jaw_prior,
                               interpenetration=interpenetration,
                               pen_distance=pen_distance,
                               search_tree=search_tree,
                               tri_filtering_module=filter_faces,
                               s2m=s2m,
                               m2s=m2s,
                               rho_s2m=rho_s2m,
                               rho_m2s=rho_m2s,
                               head_mask=head_mask,
                               body_mask=body_mask,
                               sdf_penetration=sdf_penetration,
                               voxel_size=voxel_size,
                               grid_min=grid_min,
                               grid_max=grid_max,
                               sdf=sdf,
                               sdf_normals=sdf_normals,
                               R=R,
                               t=t,
                               contact=contact,
                               contact_verts_ids=contact_verts_ids,
                               rho_contact=rho_contact,
                               contact_angle=contact_angle,
                               dtype=dtype,
                               **kwargs)
    loss = loss.to(device=device)

    with fitting.FittingMonitor(batch_size=batch_size,
                                visualize=visualize,
                                viz_mode=viz_mode,
                                **kwargs) as monitor:

        img = torch.tensor(img, dtype=dtype)

        H, W, _ = img.shape

        # Reset the parameters to estimate the initial translation of the
        # body model
        if camera_mode == 'moving':
            body_model.reset_params(body_pose=body_mean_pose)
            # Update the value of the translation of the camera as well as
            # the image center.
            with torch.no_grad():
                camera.translation[:] = init_t.view_as(camera.translation)
                camera.center[:] = torch.tensor([W, H], dtype=dtype) * 0.5

            # Re-enable gradient calculation for the camera translation
            camera.translation.requires_grad = True

            camera_opt_params = [camera.translation, body_model.global_orient]

        elif camera_mode == 'fixed':
            body_model.reset_params(body_pose=body_mean_pose, transl=init_t)
            camera_opt_params = [body_model.transl, body_model.global_orient]

        # If the distance between the 2D shoulders is smaller than a
        # predefined threshold, try two fits: the initial one and a 180
        # degree rotation
        shoulder_dist = torch.dist(gt_joints[:, left_shoulder_idx],
                                   gt_joints[:, right_shoulder_idx])
        try_both_orient = shoulder_dist.item() < side_view_thsh

        camera_optimizer, camera_create_graph = optim_factory.create_optimizer(
            camera_opt_params, **kwargs)

        # The closure passed to the optimizer
        fit_camera = monitor.create_fitting_closure(
            camera_optimizer,
            body_model,
            camera,
            gt_joints,
            camera_loss,
            create_graph=camera_create_graph,
            use_vposer=use_vposer,
            vposer=vposer,
            pose_embedding=pose_embedding,
            scan_tensor=scan_tensor,
            return_full_pose=False,
            return_verts=False)

        # Step 1: Optimize over the torso joints the camera translation
        # Initialize the computational graph by feeding the initial translation
        # of the camera and the initial pose of the body model.
        camera_init_start = time.time()
        cam_init_loss_val = monitor.run_fitting(camera_optimizer,
                                                fit_camera,
                                                camera_opt_params,
                                                body_model,
                                                use_vposer=use_vposer,
                                                pose_embedding=pose_embedding,
                                                vposer=vposer)

        if interactive:
            if use_cuda and torch.cuda.is_available():
                torch.cuda.synchronize()
            tqdm.write('Camera initialization done after {:.4f}'.format(
                time.time() - camera_init_start))
            tqdm.write('Camera initialization final loss {:.4f}'.format(
                cam_init_loss_val))

        # If the 2D detections/positions of the shoulder joints are too
        # close, rotate the body by 180 degrees and also fit to that
        # orientation
        if try_both_orient:
            body_orient = body_model.global_orient.detach().cpu().numpy()
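            # Compose the current global orientation with a 180 degree rotation about the
            # vertical (y) axis, then convert back to an axis-angle vector.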
            flipped_orient = cv2.Rodrigues(body_orient)[0].dot(
                cv2.Rodrigues(np.array([0., np.pi, 0]))[0])
            flipped_orient = cv2.Rodrigues(flipped_orient)[0].ravel()

            flipped_orient = torch.tensor(flipped_orient,
                                          dtype=dtype,
                                          device=device).unsqueeze(dim=0)
            orientations = [body_orient, flipped_orient]
        else:
            orientations = [body_model.global_orient.detach().cpu().numpy()]

        # store here the final error for both orientations,
        # and pick the orientation resulting in the lowest error
        results = []
        body_transl = body_model.transl.clone().detach()
        # Step 2: Optimize the full model
        final_loss_val = 0
        for or_idx, orient in enumerate(tqdm(orientations,
                                             desc='Orientation')):
            opt_start = time.time()

            new_params = defaultdict(transl=body_transl,
                                     global_orient=orient,
                                     body_pose=body_mean_pose)
            body_model.reset_params(**new_params)
            if use_vposer:
                with torch.no_grad():
                    pose_embedding.fill_(0)

            for opt_idx, curr_weights in enumerate(
                    tqdm(opt_weights, desc='Stage')):
                if opt_idx not in trans_opt_stages:
                    body_model.transl.requires_grad = False
                else:
                    body_model.transl.requires_grad = True
                body_params = list(body_model.parameters())

                final_params = list(
                    filter(lambda x: x.requires_grad, body_params))

                if use_vposer:
                    final_params.append(pose_embedding)

                body_optimizer, body_create_graph = optim_factory.create_optimizer(
                    final_params, **kwargs)
                body_optimizer.zero_grad()

                curr_weights['bending_prior_weight'] = (
                    3.17 * curr_weights['body_pose_weight'])
                if use_hands:
                    joint_weights[:, 25:76] = curr_weights['hand_weight']
                if use_face:
                    joint_weights[:, 76:] = curr_weights['face_weight']
                loss.reset_loss_weights(curr_weights)

                closure = monitor.create_fitting_closure(
                    body_optimizer,
                    body_model,
                    camera=camera,
                    gt_joints=gt_joints,
                    joints_conf=joints_conf,
                    joint_weights=joint_weights,
                    loss=loss,
                    create_graph=body_create_graph,
                    use_vposer=use_vposer,
                    vposer=vposer,
                    pose_embedding=pose_embedding,
                    scan_tensor=scan_tensor,
                    scene_v=scene_v,
                    scene_vn=scene_vn,
                    scene_f=scene_f,
                    ftov=ftov,
                    return_verts=True,
                    return_full_pose=True)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    stage_start = time.time()
                final_loss_val = monitor.run_fitting(
                    body_optimizer,
                    closure,
                    final_params,
                    body_model,
                    pose_embedding=pose_embedding,
                    vposer=vposer,
                    use_vposer=use_vposer)

                if interactive:
                    if use_cuda and torch.cuda.is_available():
                        torch.cuda.synchronize()
                    elapsed = time.time() - stage_start
                    if interactive:
                        tqdm.write(
                            'Stage {:03d} done after {:.4f} seconds'.format(
                                opt_idx, elapsed))

            if interactive:
                if use_cuda and torch.cuda.is_available():
                    torch.cuda.synchronize()
                elapsed = time.time() - opt_start
                tqdm.write(
                    'Body fitting Orientation {} done after {:.4f} seconds'.
                    format(or_idx, elapsed))
                tqdm.write(
                    'Body final loss val = {:.5f}'.format(final_loss_val))

            # Get the result of the fitting process
            # Store in it the errors list in order to compare multiple
            # orientations, if they exist
            result = {
                'camera_' + str(key): val.detach().cpu().numpy()
                for key, val in camera.named_parameters()
            }
            result.update({
                key: val.detach().cpu().numpy()
                for key, val in body_model.named_parameters()
            })
            if use_vposer:
                result['pose_embedding'] = pose_embedding.detach().cpu().numpy(
                )
                body_pose = vposer.decode(pose_embedding,
                                          output_type='aa').view(
                                              1, -1) if use_vposer else None
                result['body_pose'] = body_pose.detach().cpu().numpy()

            results.append({'loss': final_loss_val, 'result': result})

        with open(result_fn, 'wb') as result_file:
            if len(results) > 1:
                min_idx = (0 if results[0]['loss'] < results[1]['loss'] else 1)
            else:
                min_idx = 0
            pickle.dump(results[min_idx]['result'], result_file, protocol=2)

    if save_meshes or visualize:
        body_pose = vposer.decode(pose_embedding, output_type='aa').view(
            1, -1) if use_vposer else None

        model_type = kwargs.get('model_type', 'smpl')
        append_wrists = model_type == 'smpl' and use_vposer
        if append_wrists:
            wrist_pose = torch.zeros([body_pose.shape[0], 6],
                                     dtype=body_pose.dtype,
                                     device=body_pose.device)
            body_pose = torch.cat([body_pose, wrist_pose], dim=1)

        model_output = body_model(return_verts=True, body_pose=body_pose)
        vertices = model_output.vertices.detach().cpu().numpy().squeeze()

        import trimesh

        out_mesh = trimesh.Trimesh(vertices, body_model.faces, process=False)
        out_mesh.export(mesh_fn)

    if render_results:
        import pyrender

        # common
        H, W = 1080, 1920
        camera_center = np.array([951.30, 536.77])
        camera_pose = np.eye(4)
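        # Flip the y and z rows of the pose to convert from the OpenCV camera convention
        # to the OpenGL/pyrender convention.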
        camera_pose = np.array([1.0, -1.0, -1.0, 1.0]).reshape(-1,
                                                               1) * camera_pose
        camera = pyrender.camera.IntrinsicsCamera(fx=1060.53,
                                                  fy=1060.38,
                                                  cx=camera_center[0],
                                                  cy=camera_center[1])
        light = pyrender.DirectionalLight(color=np.ones(3), intensity=2.0)

        material = pyrender.MetallicRoughnessMaterial(
            metallicFactor=0.0,
            alphaMode='OPAQUE',
            baseColorFactor=(1.0, 1.0, 0.9, 1.0))
        body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)

        ## rendering body
        img = img.detach().cpu().numpy()
        H, W, _ = img.shape

        scene = pyrender.Scene(bg_color=[0.0, 0.0, 0.0, 0.0],
                               ambient_light=(0.3, 0.3, 0.3))
        scene.add(camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)
        # for node in light_nodes:
        #     scene.add_node(node)

        scene.add(body_mesh, 'mesh')

        r = pyrender.OffscreenRenderer(viewport_width=W,
                                       viewport_height=H,
                                       point_size=1.0)
        color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
        color = color.astype(np.float32) / 255.0

        valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
        input_img = img
        output_img = (color[:, :, :-1] * valid_mask +
                      (1 - valid_mask) * input_img)

        img = pil_img.fromarray((output_img * 255).astype(np.uint8))
        img.save(out_img_fn)

        ## rendering body + scene
        body_mesh = pyrender.Mesh.from_trimesh(out_mesh, material=material)
        static_scene = trimesh.load(osp.join(scene_dir, scene_name + '.ply'))
        trans = np.linalg.inv(cam2world)
        static_scene.apply_transform(trans)

        static_scene_mesh = pyrender.Mesh.from_trimesh(static_scene)

        scene = pyrender.Scene()
        scene.add(camera, pose=camera_pose)
        scene.add(light, pose=camera_pose)

        scene.add(static_scene_mesh, 'mesh')
        scene.add(body_mesh, 'mesh')

        r = pyrender.OffscreenRenderer(viewport_width=W, viewport_height=H)
        color, _ = r.render(scene)
        color = color.astype(np.float32) / 255.0
        img = pil_img.fromarray((color * 255).astype(np.uint8))
        img.save(body_scene_rendering_fn)
Example #27
    def pose_result(self,
                    verts,
                    pose_params,
                    save_obj,
                    cloth_type=None,
                    obj_dir=None):
        '''
        :param verts: [N, 6890, 3]
        :param pose_params: [N, 72]
        '''
        if verts.shape[0] != 1:  # minimal shape: pose it to every pose
            assert verts.shape[0] == pose_params.shape[
                0]  # otherwise the number of results should equal the number of pose identities

        verts_posed = []

        if save_obj:
            if not exists(obj_dir):
                os.makedirs(obj_dir)
            print('saving results as .obj files to {}...'.format(obj_dir))

        if verts.shape[0] == 1:
            self.smpl_model.v_template[:] = torch.from_numpy(verts[0])
            for i in range(len(pose_params)):
                # model.pose[:] = pose_params[i]
                self.smpl_model.body_pose[:] = torch.from_numpy(
                    pose_params[i][3:])
                self.smpl_model.global_orient[:] = torch.from_numpy(
                    pose_params[i][:3])
                verts_out = self.smpl_model().vertices.detach().cpu().numpy()
                verts_posed.append(verts_out)
                if save_obj:
                    if cloth_type is not None:
                        Mesh(verts_out.squeeze(),
                             self.smpl_model.faces).write_obj(
                                 join(obj_dir,
                                      '{}_{:0>4d}.obj').format(cloth_type, i))
                    else:
                        Mesh(verts_out.squeeze(),
                             self.smpl_model.faces).write_obj(
                                 join(obj_dir, '{:0>4d}.obj').format(i))
        else:
            for i in range(len(verts)):
                self.smpl_model.v_template[:] = torch.from_numpy(verts[i])
                self.smpl_model.body_pose[:] = torch.from_numpy(
                    pose_params[i][3:])
                self.smpl_model.global_orient[:] = torch.from_numpy(
                    pose_params[i][:3])
                verts_out = self.smpl_model().vertices.detach().cpu().numpy()
                verts_posed.append(verts_out)
                if save_obj:
                    if cloth_type is not None:
                        Mesh(verts_out.squeeze(),
                             self.smpl_model.faces).write_obj(
                                 join(obj_dir,
                                      '{}_{:0>4d}.obj').format(cloth_type, i))
                    else:
                        Mesh(verts_out.squeeze(),
                             self.smpl_model.faces).write_obj(
                                 join(obj_dir, '{:0>4d}.obj').format(i))

        return verts_posed
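
    # Hedged usage sketch (illustrative only; the file names and shapes below are
    # assumptions, not part of the original code):
    #   minimal_shape = np.load('minimal_shape.npy')   # e.g. [1, 6890, 3]
    #   poses = np.load('poses.npy')                    # e.g. [N, 72]
    #   posed = self.pose_result(minimal_shape, poses, save_obj=True, obj_dir='./posed')
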
    def _render_sequences_helper(self, video_fname, seq_raw_audio,
                                 seq_processed_audio, seq_template, seq_verts,
                                 condition_idx):
        def add_image_text(img, text):
            font = cv2.FONT_HERSHEY_SIMPLEX
            textsize = cv2.getTextSize(text, font, 1, 2)[0]
            textX = (img.shape[1] - textsize[0]) // 2
            textY = textsize[1] + 10
            cv2.putText(img, '%s' % (text), (textX, textY), font, 1,
                        (0, 0, 255), 2, cv2.LINE_AA)

        num_frames = seq_verts.shape[0]
        tmp_audio_file = tempfile.NamedTemporaryFile(
            'w', suffix='.wav', dir=os.path.dirname(video_fname))
        wavfile.write(tmp_audio_file.name, seq_raw_audio['sample_rate'],
                      seq_raw_audio['audio'])

        tmp_video_file = tempfile.NamedTemporaryFile(
            'w', suffix='.mp4', dir=os.path.dirname(video_fname))
        if int(cv2.__version__[0]) < 3:
            print('cv2 < 3')
            writer = cv2.VideoWriter(tmp_video_file.name,
                                     cv2.cv.CV_FOURCC(*'mp4v'), 25,
                                     (1600, 800), True)
        else:
            print('cv2 >= 3')
            writer = cv2.VideoWriter(tmp_video_file.name,
                                     cv2.VideoWriter_fourcc(*'mp4v'), 25,
                                     (1600, 800), True)

        feed_dict = {
            self.speech_features:
            np.expand_dims(np.stack(seq_processed_audio), -1),
            self.condition_subject_id:
            np.repeat(condition_idx, num_frames),
            self.is_training:
            False,
            self.input_template:
            np.repeat(seq_template[np.newaxis, :, :, np.newaxis],
                      num_frames,
                      axis=0)
        }

        predicted_vertices, predicted_offset = self.session.run(
            [self.output_decoder, self.expression_offset], feed_dict)
        predicted_vertices = np.squeeze(predicted_vertices)
        center = np.mean(seq_verts[0], axis=0)

        for i_frame in range(min(200, num_frames)):  # render at most 200 frames
            gt_img = render_mesh_helper(
                Mesh(seq_verts[i_frame], self.template_mesh.f), center)
            # add_image_text(gt_img, 'Captured data')
            pred_img = render_mesh_helper(
                Mesh(predicted_vertices[i_frame], self.template_mesh.f),
                center)
            # add_image_text(pred_img, 'VOCA prediction')
            img = np.hstack((gt_img, pred_img))
            writer.write(img)
        writer.release()

        cmd = (
            'ffmpeg' +
            ' -i {0} -i {1} -vcodec h264 -ac 2 -strict -2 -channel_layout stereo -pix_fmt yuv420p  {2}'
            .format(tmp_audio_file.name, tmp_video_file.name,
                    video_fname)).split()
        call(cmd)
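
# The helper above writes a silent video with OpenCV and then muxes the audio track
# with ffmpeg. A self-contained sketch of that pattern, with illustrative file names,
# sizes, and dummy frames (it assumes ffmpeg is on the PATH and audio.wav exists):
import subprocess
import cv2
import numpy as np

silent_video = 'frames_only.mp4'
writer = cv2.VideoWriter(silent_video, cv2.VideoWriter_fourcc(*'mp4v'),
                         25, (1600, 800), True)
for _ in range(50):
    frame = np.zeros((800, 1600, 3), dtype=np.uint8)  # dummy BGR frame
    writer.write(frame)
writer.release()

# Mux the audio track into the silent video to produce the final file.
subprocess.call(['ffmpeg', '-y', '-i', 'audio.wav', '-i', silent_video,
                 '-vcodec', 'h264', '-pix_fmt', 'yuv420p', 'output.mp4'])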
Example #29
def main(args):
    if not os.path.exists(args.conf):
        print('Config not found: ' + args.conf)
        return

    config = read_config(args.conf)
    for k in config.keys():
        print(k, config[k])

    print('Initializing parameters')
    template_file_path = config['template_fname']
    template_mesh = Mesh(filename=template_file_path)

    if args.checkpoint_dir:
        checkpoint_dir = args.checkpoint_dir
    else:
        checkpoint_dir = config['checkpoint_dir']
    if not os.path.exists(checkpoint_dir):
        os.makedirs(checkpoint_dir)

    visualize = config['visualize']
    output_dir = config['visual_output_dir']
    if visualize is True and not output_dir:
        print('No visual output directory provided; the checkpoint directory '
              'will be used to store the visual results.')
        output_dir = checkpoint_dir

    if output_dir and not os.path.exists(output_dir):
        os.makedirs(output_dir)

    eval_flag = config['eval']
    lr = config['learning_rate']
    lr_decay = config['learning_rate_decay']
    weight_decay = config['weight_decay']
    total_epochs = config['epoch']
    workers_thread = config['workers_thread']
    opt = config['optimizer']
    batch_size = config['batch_size']
    val_losses, accs, durations = [], [], []

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Generating transforms')
    M, A, D, U = mesh_operations.generate_transform_matrices(
        template_mesh, config['downsampling_factors'])

    D_t = [scipy_to_torch_sparse(d).to(device) for d in D]
    U_t = [scipy_to_torch_sparse(u).to(device) for u in U]
    A_t = [scipy_to_torch_sparse(a).to(device) for a in A]
    num_nodes = [len(M[i].v) for i in range(len(M))]

    print('Loading Dataset')
    if args.data_dir:
        data_dir = args.data_dir
    else:
        data_dir = config['data_dir']

    normalize_transform = Normalize()
    dataset = ComaDataset(data_dir,
                          dtype='train',
                          split=args.split,
                          split_term=args.split_term,
                          pre_transform=normalize_transform)
    dataset_val = ComaDataset(data_dir,
                              dtype='val',
                              split=args.split,
                              split_term=args.split_term,
                              pre_transform=normalize_transform)
    dataset_test = ComaDataset(data_dir,
                               dtype='test',
                               split=args.split,
                               split_term=args.split_term,
                               pre_transform=normalize_transform)
    train_loader = DataLoader(dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=workers_thread)
    val_loader = DataLoader(dataset_val,
                            batch_size=1,
                            shuffle=True,
                            num_workers=workers_thread)
    test_loader = DataLoader(dataset_test,
                             batch_size=1,
                             shuffle=False,
                             num_workers=workers_thread)

    print('Loading model')
    start_epoch = 1
    coma = Coma(dataset, config, D_t, U_t, A_t, num_nodes)
    if opt == 'adam':
        optimizer = torch.optim.Adam(coma.parameters(),
                                     lr=lr,
                                     weight_decay=weight_decay)
    elif opt == 'sgd':
        optimizer = torch.optim.SGD(coma.parameters(),
                                    lr=lr,
                                    weight_decay=weight_decay,
                                    momentum=0.9)
    else:
        raise Exception('Unrecognized optimizer: {}'.format(opt))

    checkpoint_file = config['checkpoint_file']
    print(checkpoint_file)
    if checkpoint_file:
        checkpoint = torch.load(checkpoint_file)
        start_epoch = checkpoint['epoch_num']
        coma.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Workaround: move the loaded optimizer state tensors to the target device
        # (check whether newer PyTorch versions handle this automatically)
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)
    coma.to(device)

    if eval_flag:
        test_loss = evaluate(coma, output_dir, test_loader, dataset_test,
                             template_mesh, device, start_epoch,
                             visualize=visualize)
        print('test loss', test_loss)
        return

    best_val_loss = float('inf')
    val_loss_history = []

    from datetime import datetime
    current_time = datetime.now().strftime('%b%d_%H-%M-%S')
    log_dir = os.path.join('runs/clae_dxo_de', current_time)
    writer = SummaryWriter(log_dir + 'ds2_lr0.04_z8')

    for epoch in range(start_epoch, total_epochs + 1):
        print("Training for epoch ", epoch)
        train_loss = train(coma, train_loader, len(dataset), optimizer, device)
        val_loss = evaluate(coma,
                            output_dir,
                            val_loader,
                            dataset_val,
                            template_mesh,
                            device,
                            epoch,
                            visualize=visualize)

        writer.add_scalar('data/train_loss', train_loss, epoch)
        writer.add_scalar('data/val_loss', val_loss, epoch)

        print('epoch ', epoch, ' Train loss ', train_loss, ' Val loss ',
              val_loss)
        if val_loss < best_val_loss:
            save_model(coma, optimizer, epoch, train_loss, val_loss,
                       checkpoint_dir)
            best_val_loss = val_loss

        if epoch == total_epochs or epoch % 100 == 0:
            save_model(coma, optimizer, epoch, train_loss, val_loss,
                       checkpoint_dir)

        val_loss_history.append(val_loss)
        val_losses.append(best_val_loss)

        if opt == 'sgd':
            adjust_learning_rate(optimizer, lr_decay)

    if torch.cuda.is_available():
        torch.cuda.synchronize()

    writer.close()
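
# main() above uses a scipy_to_torch_sparse helper that is not shown in this snippet.
# A minimal sketch of what such a conversion typically looks like (an assumption,
# not necessarily the repository's implementation):
import numpy as np
import scipy.sparse
import torch

def scipy_to_torch_sparse(mat):
    # Convert a scipy sparse matrix to a torch sparse COO tensor.
    mat = mat.tocoo().astype(np.float32)
    indices = torch.from_numpy(np.vstack((mat.row, mat.col)).astype(np.int64))
    values = torch.from_numpy(mat.data)
    return torch.sparse_coo_tensor(indices, values, torch.Size(mat.shape))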
def render_mesh_helper(mesh,
                       t_center,
                       rot=np.zeros(3),
                       tex_img=None,
                       v_colors=None,
                       errors=None,
                       error_unit='m',
                       min_dist_in_mm=0.0,
                       max_dist_in_mm=3.0,
                       z_offset=0):
    camera_params = {
        'c': np.array([400, 400]),
        'k': np.array([-0.19816071, 0.92822711, 0, 0, 0]),
        'f': np.array([4754.97941935 / 2, 4754.97941935 / 2])
    }

    frustum = {'near': 0.01, 'far': 3.0, 'height': 800, 'width': 800}

    mesh_copy = Mesh(mesh.v, mesh.f)
    mesh_copy.v[:] = cv2.Rodrigues(rot)[0].dot(
        (mesh_copy.v - t_center).T).T + t_center

    texture_rendering = tex_img is not None and hasattr(
        mesh, 'vt') and hasattr(mesh, 'ft')
    if texture_rendering:
        intensity = 0.5
        tex = pyrender.Texture(source=tex_img, source_channels='RGB')
        material = pyrender.material.MetallicRoughnessMaterial(
            baseColorTexture=tex)

        # Workaround as pyrender requires number of vertices and uv coordinates to be the same
        temp_filename = '%s.obj' % next(tempfile._get_candidate_names())
        mesh.write_obj(temp_filename)
        tri_mesh = trimesh.load(temp_filename, process=False)
        try:
            os.remove(temp_filename)
        except OSError:
            print('Failed deleting temporary file - %s' % temp_filename)
        render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, material=material)
    elif errors is not None:
        intensity = 0.5
        unit_factor = get_unit_factor('mm') / get_unit_factor(error_unit)
        errors = unit_factor * errors

        norm = mpl.colors.Normalize(vmin=min_dist_in_mm, vmax=max_dist_in_mm)
        cmap = cm.get_cmap(name='jet')
        colormapper = cm.ScalarMappable(norm=norm, cmap=cmap)
        rgba_per_v = colormapper.to_rgba(errors)
        rgb_per_v = rgba_per_v[:, 0:3]
    elif v_colors is not None:
        intensity = 0.5
        rgb_per_v = v_colors
    else:
        intensity = 1.5
        rgb_per_v = None

    if not texture_rendering:
        tri_mesh = trimesh.Trimesh(vertices=mesh_copy.v,
                                   faces=mesh_copy.f,
                                   vertex_colors=rgb_per_v)
        render_mesh = pyrender.Mesh.from_trimesh(tri_mesh, smooth=True)

    scene = pyrender.Scene(ambient_light=[.2, .2, .2],
                           bg_color=[255, 255, 255])
    camera = pyrender.IntrinsicsCamera(fx=camera_params['f'][0],
                                       fy=camera_params['f'][1],
                                       cx=camera_params['c'][0],
                                       cy=camera_params['c'][1],
                                       znear=frustum['near'],
                                       zfar=frustum['far'])

    scene.add(render_mesh, pose=np.eye(4))

    camera_pose = np.eye(4)
    camera_pose[:3, 3] = np.array([0, 0, 1.0 - z_offset])
    scene.add(camera, pose=camera_pose)

    angle = np.pi / 6.0
    pos = camera_pose[:3, 3]
    light_color = np.array([1., 1., 1.])
    light = pyrender.PointLight(color=light_color, intensity=intensity)

    light_pose = np.eye(4)
    light_pose[:3, 3] = pos
    scene.add(light, pose=light_pose.copy())

    light_pose[:3, 3] = cv2.Rodrigues(np.array([angle, 0, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3, 3] = cv2.Rodrigues(np.array([-angle, 0, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3, 3] = cv2.Rodrigues(np.array([0, -angle, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    light_pose[:3, 3] = cv2.Rodrigues(np.array([0, angle, 0]))[0].dot(pos)
    scene.add(light, pose=light_pose.copy())

    flags = pyrender.RenderFlags.SKIP_CULL_FACES
    try:
        r = pyrender.OffscreenRenderer(viewport_width=frustum['width'],
                                       viewport_height=frustum['height'])
        color, _ = r.render(scene, flags=flags)
    except Exception:
        print('pyrender: Failed rendering frame')
        color = np.zeros((frustum['height'], frustum['width'], 3),
                         dtype='uint8')

    return color[..., ::-1]
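
# A minimal usage sketch for render_mesh_helper (the mesh path is illustrative):
# the function returns an 800x800 BGR image, so it can be written directly with OpenCV.
import numpy as np
import cv2
from psbody.mesh import Mesh

mesh = Mesh(filename='./data/template.obj')   # any mesh in the expected topology
center = np.mean(mesh.v, axis=0)              # render around the mesh centroid
img_bgr = render_mesh_helper(mesh, center)    # 800x800 BGR image
cv2.imwrite('template_render.png', np.ascontiguousarray(img_bgr))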