Example #1
    def params_rnet(self, dorig):
        rh_mesh = Meshes(verts=dorig['verts_rhand_f'], faces=self.rh_f).to(
            self.device).verts_normals_packed().view(-1, 778, 3)
        rh_mesh_gt = Meshes(verts=dorig['verts_rhand'], faces=self.rh_f).to(
            self.device).verts_normals_packed().view(-1, 778, 3)

        o2h_signed, h2o, _ = point2point_signed(dorig['verts_rhand_f'],
                                                dorig['verts_object'], rh_mesh)
        o2h_signed_gt, h2o_gt, _ = point2point_signed(dorig['verts_rhand'],
                                                      dorig['verts_object'],
                                                      rh_mesh_gt)

        h2o = h2o.abs()
        h2o_gt = h2o_gt.abs()

        return {'h2o_dist': h2o, 'h2o_gt': h2o_gt, 'o2h_gt': o2h_signed_gt}
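
Note: `h2o` above is a per-hand-vertex distance to the object point cloud returned by the repo's `point2point_signed`, taken as an absolute value because only the magnitude is needed downstream. The sketch below is a plain-PyTorch stand-in for the unsigned part of that computation (it ignores the normal-based sign handling and uses toy shapes), shown only to make the tensor shapes concrete; it is not the repo's implementation.

import torch

# Toy stand-in for the unsigned part of point2point_signed:
# for every hand vertex, the distance to its nearest object point.
verts_rhand = torch.rand(4, 778, 3)     # (batch, MANO hand verts, xyz)
verts_object = torch.rand(4, 2048, 3)   # (batch, object points, xyz) -- toy count

d = torch.cdist(verts_rhand, verts_object)   # (4, 778, 2048) pairwise distances
h2o_unsigned = d.min(dim=-1).values          # (4, 778), always >= 0
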
Example #2
    def loss_rnet(self, dorig, drec, ds_name='train'):

        output = self.rhm_train(**drec)
        verts_rhand = output.vertices

        rh_mesh = Meshes(verts=verts_rhand, faces=self.rh_f).to(
            self.device).verts_normals_packed().view(-1, 778, 3)
        h2o_gt = dorig['h2o_gt']
        o2h_signed, h2o, _ = point2point_signed(verts_rhand,
                                                dorig['verts_object'], rh_mesh)
        ######### dist loss
        loss_dist_h = 35 * (1. - self.cfg.kl_coef) * torch.mean(
            torch.einsum('ij,j->ij', torch.abs(h2o.abs() - h2o_gt.abs()),
                         self.v_weights2))
        ########## verts loss
        loss_mesh_rec_w = 20 * (1. - self.cfg.kl_coef) * torch.mean(
            torch.einsum(
                'ijk,j->ijk', torch.abs(
                    (dorig['verts_rhand'] - verts_rhand)), self.v_weights2))
        ########## edge loss
        loss_edge = 10 * (1. - self.cfg.kl_coef) * self.LossL1(
            self.edges_for(verts_rhand, self.vpe),
            self.edges_for(dorig['verts_rhand'], self.vpe))
        ##########

        loss_dict = {
            'loss_edge_r': loss_edge,
            'loss_mesh_rec_r': loss_mesh_rec_w,
            'loss_dist_h_r': loss_dist_h,
        }

        loss_total = torch.stack(list(loss_dict.values())).sum()
        loss_dict['loss_total'] = loss_total

        return loss_total, loss_dict
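
Note: the `'ij,j->ij'` einsum in `loss_dist_h` and the `'ijk,j->ijk'` einsum in `loss_mesh_rec_w` simply scale vertex `j` of every batch element by a per-vertex weight before the mean is taken. The toy snippet below (made-up shapes; `v_weights2` is assumed to be a length-778 weight vector, as the indexing suggests) shows the operation is equivalent to plain broadcasting.

import torch

diff = torch.rand(4, 778)         # e.g. |h2o| - |h2o_gt| per hand vertex
v_weights2 = torch.rand(778)      # assumed per-vertex weights

weighted = torch.einsum('ij,j->ij', diff, v_weights2)
assert torch.allclose(weighted, diff * v_weights2)   # same as broadcasting over the batch
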
Example #3
    def forward(self, h2o_dist, fpose_rhand_rotmat_f, trans_rhand_f,
                global_orient_rhand_rotmat_f, verts_object, **kwargs):

        bs = h2o_dist.shape[0]
        init_pose = fpose_rhand_rotmat_f[..., :2].reshape(bs, -1)
        init_rpose = global_orient_rhand_rotmat_f[..., :2].reshape(bs, -1)
        init_pose = torch.cat([init_rpose, init_pose], dim=1)
        init_trans = trans_rhand_f

        for i in range(self.n_iters):

            if i != 0:
                hand_parms = parms_decode(init_pose, init_trans)
                verts_rhand = self.rhm_train(**hand_parms).vertices
                _, h2o_dist, _ = point2point_signed(verts_rhand, verts_object)

            h2o_dist = self.bn1(h2o_dist)
            X0 = torch.cat([h2o_dist, init_pose, init_trans], dim=1)
            X = self.rb1(X0)
            X = self.dout(X)
            X = self.rb2(torch.cat([X, X0], dim=1))
            X = self.dout(X)
            X = self.rb3(torch.cat([X, X0], dim=1))
            X = self.dout(X)

            pose = self.out_p(X)
            trans = self.out_t(X)

            init_trans = init_trans + trans
            init_pose = init_pose + pose

        hand_parms = parms_decode(init_pose, init_trans)
        return hand_parms
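
Note: taking `[..., :2]` of each 3x3 rotation matrix and flattening it yields the 6D rotation representation (two columns per joint), which `parms_decode` presumably maps back to full rotation matrices. That decoder is repo-specific and not shown here; the sketch below is the standard Gram-Schmidt recovery (Zhou et al.) for one such 6D vector, included only to illustrate the idea -- the repo's exact flattening order and implementation may differ.

import torch
import torch.nn.functional as F

def rot6d_to_rotmat(x):
    # x: (..., 6) holding two 3D column vectors; returns (..., 3, 3).
    a1, a2 = x[..., :3], x[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack([b1, b2, b3], dim=-1)   # b1, b2, b3 as matrix columns
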
Example #4
def vis_results(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):

    with torch.no_grad():
        imw, imh = 1920, 780
        cols = len(dorig['bps_object'])
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        mvs = MeshViewers(window_width=imw, window_height=imh, shape=[1, cols], keepalive=True)

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, dorig['verts_object'].to(device))

        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = dorig['verts_object'].to(device)
        drec_cnet['h2o_dist']= h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        verts_rh_gen_rnet = rh_model(**drec_rnet).vertices


        for cId in range(0, len(dorig['bps_object'])):
            try:
                from copy import deepcopy
                meshes = deepcopy(dorig['mesh_object'])
                obj_mesh = meshes[cId]
            except:
                obj_mesh = points_to_spheres(to_cpu(dorig['verts_object'][cId]), radius=0.002, vc=name_to_rgb['green'])

            hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]), f=rh_model.faces, vc=name_to_rgb['pink'])
            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]), f=rh_model.faces, vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_cnet.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            hand_mesh_gen_cnet.reset_face_normals()
            hand_mesh_gen_rnet.reset_face_normals()

            # mvs[0][cId].set_static_meshes([hand_mesh_gen_cnet] + obj_mesh, blocking=True)
            mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet, obj_mesh], blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh[0].write_ply(filename=save_path + '/obj_mesh_%d.ply' % cId)
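
Note: `aa2rotmat` above converts axis-angle vectors (MANO's `global_orient` and `hand_pose`) into rotation matrices before they are passed to the refinement network. That helper is repo-specific; as a reference, the Rodrigues formula it corresponds to can be written in plain PyTorch as follows (a sketch, not the repo's code).

import torch

def axis_angle_to_rotmat(aa):
    # aa: (N, 3) axis-angle vectors -> (N, 3, 3) rotation matrices
    # via Rodrigues: R = I + sin(t) K + (1 - cos(t)) K @ K.
    theta = aa.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    x, y, z = (aa / theta).unbind(-1)
    zero = torch.zeros_like(x)
    K = torch.stack([zero, -z, y,
                     z, zero, -x,
                     -y, x, zero], dim=-1).view(-1, 3, 3)
    theta = theta.view(-1, 1, 1)
    eye = torch.eye(3, dtype=aa.dtype, device=aa.device)
    return eye + torch.sin(theta) * K + (1 - torch.cos(theta)) * (K @ K)
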
Example #5
def get_meshes(dorig, coarse_net, refine_net, rh_model, save=False, save_dir=None):
    with torch.no_grad():

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

        _, h2o, _ = point2point_signed(verts_rh_gen_cnet, dorig['verts_object'].to(device))

        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = dorig['verts_object'].to(device)
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        verts_rh_gen_rnet = rh_model(**drec_rnet).vertices

        gen_meshes = []
        for cId in range(0, len(dorig['bps_object'])):
            try:
                obj_mesh = dorig['mesh_object'][cId]
            except:
                obj_mesh = points2sphere(points=to_cpu(dorig['verts_object'][cId]), radius=0.002, vc=name_to_rgb['yellow'])

            hand_mesh_gen_rnet = Mesh(vertices=to_cpu(verts_rh_gen_rnet[cId]), faces=rh_model.faces, vc=[245, 191, 177])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

            gen_meshes.append([obj_mesh, hand_mesh_gen_rnet])
            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.export(filename=save_path + '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh.export(filename=save_path + '/obj_mesh_%d.ply' % cId)

        return gen_meshes
Example #6
    def loss_cnet(self, dorig, drec, ds_name='train'):

        device = dorig['verts_rhand'].device
        dtype = dorig['verts_rhand'].dtype

        q_z = torch.distributions.normal.Normal(drec['mean'], drec['std'])

        output = self.rhm_train(**drec)
        verts_rhand = output.vertices

        rh_mesh = Meshes(verts=verts_rhand, faces=self.rh_f).to(
            self.device).verts_normals_packed().view(-1, 778, 3)
        rh_mesh_gt = Meshes(verts=dorig['verts_rhand'], faces=self.rh_f).to(
            self.device).verts_normals_packed().view(-1, 778, 3)

        o2h_signed, h2o, _ = point2point_signed(verts_rhand,
                                                dorig['verts_object'], rh_mesh)
        o2h_signed_gt, h2o_gt, o2h_idx = point2point_signed(
            dorig['verts_rhand'], dorig['verts_object'], rh_mesh_gt)

        # adaptive weight for penetration and contact verts
        w_dist = (o2h_signed_gt < 0.01) * (o2h_signed_gt > -0.005)
        w_dist_neg = o2h_signed < 0.
        w = self.w_dist.clone()
        w[~w_dist] = .1  # less weight for far away vertices
        w[w_dist_neg] = 1.5  # more weight for penetration
        ######### dist loss
        loss_dist_h = 35 * (1. - self.cfg.kl_coef) * torch.mean(
            torch.einsum('ij,j->ij', torch.abs(h2o.abs() - h2o_gt.abs()),
                         self.v_weights2))
        loss_dist_o = 30 * (1. - self.cfg.kl_coef) * torch.mean(
            torch.einsum('ij,ij->ij', torch.abs(o2h_signed - o2h_signed_gt),
                         w))
        ########## verts loss
        loss_mesh_rec_w = 35 * (1. - self.cfg.kl_coef) * torch.mean(
            torch.einsum(
                'ijk,j->ijk', torch.abs(
                    (dorig['verts_rhand'] - verts_rhand)), self.v_weights))
        ########## edge loss
        loss_edge = 30 * (1. - self.cfg.kl_coef) * self.LossL1(
            self.edges_for(verts_rhand, self.vpe),
            self.edges_for(dorig['verts_rhand'], self.vpe))
        ########## KL loss
        p_z = torch.distributions.normal.Normal(
            loc=torch.tensor(np.zeros([self.cfg.batch_size, self.cfg.latentD]),
                             requires_grad=False).to(device).type(dtype),
            scale=torch.tensor(np.ones([self.cfg.batch_size,
                                        self.cfg.latentD]),
                               requires_grad=False).to(device).type(dtype))
        loss_kl = self.cfg.kl_coef * torch.mean(
            torch.sum(torch.distributions.kl.kl_divergence(q_z, p_z), dim=[1]))
        ##########

        loss_dict = {
            'loss_kl': loss_kl,
            'loss_edge': loss_edge,
            'loss_mesh_rec': loss_mesh_rec_w,
            'loss_dist_h': loss_dist_h,
            'loss_dist_o': loss_dist_o,
        }

        loss_total = torch.stack(list(loss_dict.values())).sum()
        loss_dict['loss_total'] = loss_total

        return loss_total, loss_dict
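
Note: the KL term compares the encoder's posterior `q_z = Normal(mean, std)` with a standard-normal prior of the same shape, summing over the latent dimension and averaging over the batch. A minimal stand-alone version of that term (toy shapes, no repo dependencies) looks like this:

import torch

mean = torch.full((8, 16), 0.3)   # toy encoder means, (batch, latentD)
std = torch.full((8, 16), 0.9)    # toy encoder stds
q_z = torch.distributions.Normal(mean, std)
p_z = torch.distributions.Normal(torch.zeros_like(mean), torch.ones_like(std))

loss_kl = torch.distributions.kl.kl_divergence(q_z, p_z).sum(dim=1).mean()
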
Example #7
def vis_results(ho,
                dorig,
                coarse_net,
                refine_net,
                rh_model,
                save=False,
                save_dir=None,
                rh_model_pkl=None,
                vis=True):

    # with torch.no_grad():
    imw, imh = 1920, 780
    cols = len(dorig['bps_object'])
    # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    device = torch.device('cpu')

    if vis:
        mvs = MeshViewers(window_width=imw,
                          window_height=imh,
                          shape=[1, cols],
                          keepalive=True)

    # drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
    #
    # for k in drec_cnet.keys():
    #     print('drec cnet', k, drec_cnet[k].shape)

    # verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    drec_cnet = {}

    hand_pose_in = torch.Tensor(ho.hand_pose[3:]).unsqueeze(0)
    mano_out_1 = rh_model_pkl(hand_pose=hand_pose_in)
    hand_pose_in = mano_out_1.hand_pose

    mTc = torch.Tensor(ho.hand_mTc)
    approx_global_orient = rotmat2aa(mTc[:3, :3].unsqueeze(0))

    if torch.isnan(approx_global_orient).any():  # Using HOnnotate?
        approx_global_orient = torch.Tensor(ho.hand_pose[:3]).unsqueeze(0)

    approx_global_orient = approx_global_orient.squeeze(1).squeeze(1)
    approx_trans = mTc[:3, 3].unsqueeze(0)

    target_verts = torch.Tensor(ho.hand_verts).unsqueeze(0)

    pose, trans, rot = util.opt_hand(rh_model, target_verts, hand_pose_in,
                                     approx_trans, approx_global_orient)

    # drec_cnet['hand_pose'] = torch.einsum('bi,ij->bj', [hand_pose_in, rh_model_pkl.hand_components])
    drec_cnet['transl'] = trans
    drec_cnet['global_orient'] = rot
    drec_cnet['hand_pose'] = pose

    verts_rh_gen_cnet = rh_model(**drec_cnet).vertices

    _, h2o, _ = point2point_signed(verts_rh_gen_cnet,
                                   dorig['verts_object'].to(device))

    drec_cnet['trans_rhand_f'] = drec_cnet['transl']
    drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
        drec_cnet['global_orient']).view(-1, 3, 3)
    drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(drec_cnet['hand_pose']).view(
        -1, 15, 3, 3)
    drec_cnet['verts_object'] = dorig['verts_object'].to(device)
    drec_cnet['h2o_dist'] = h2o.abs()

    print(
        'Hand fitting err',
        np.linalg.norm(
            verts_rh_gen_cnet.squeeze().detach().numpy() - ho.hand_verts, 2,
            1).mean())
    orig_obj = dorig['mesh_object'][0].v
    # print(orig_obj.shape, orig_obj)
    # print('Obj fitting err', np.linalg.norm(orig_obj - ho.obj_verts, 2, 1).mean())

    drec_rnet = refine_net(**drec_cnet)
    mano_out = rh_model(**drec_rnet)
    verts_rh_gen_rnet = mano_out.vertices
    joints_out = mano_out.joints

    if vis:
        for cId in range(0, len(dorig['bps_object'])):
            try:
                from copy import deepcopy
                meshes = deepcopy(dorig['mesh_object'])
                obj_mesh = meshes[cId]
            except:
                obj_mesh = points_to_spheres(to_cpu(
                    dorig['verts_object'][cId]),
                                             radius=0.002,
                                             vc=name_to_rgb['green'])

            hand_mesh_gen_cnet = Mesh(v=to_cpu(verts_rh_gen_cnet[cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['pink'])
            hand_mesh_gen_rnet = Mesh(v=to_cpu(verts_rh_gen_rnet[cId]),
                                      f=rh_model.faces,
                                      vc=name_to_rgb['gray'])

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_cnet.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)
                # print('rotmat', rotmat)

            hand_mesh_gen_cnet.reset_face_normals()
            hand_mesh_gen_rnet.reset_face_normals()

            # mvs[0][cId].set_static_meshes([hand_mesh_gen_cnet] + obj_mesh, blocking=True)
            # mvs[0][cId].set_static_meshes([hand_mesh_gen_rnet,obj_mesh], blocking=True)
            mvs[0][cId].set_static_meshes(
                [hand_mesh_gen_rnet, hand_mesh_gen_cnet, obj_mesh],
                blocking=True)

            if save:
                save_path = os.path.join(save_dir, str(cId))
                makepath(save_path)
                hand_mesh_gen_rnet.write_ply(filename=save_path +
                                             '/rh_mesh_gen_%d.ply' % cId)
                obj_mesh[0].write_ply(filename=save_path +
                                      '/obj_mesh_%d.ply' % cId)

    return verts_rh_gen_rnet, joints_out
Example #8
def get_meshes(dorig,
               coarse_net,
               refine_net,
               rh_model,
               save=False,
               save_dir=None):
    with torch.no_grad():

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        drec_cnet = coarse_net.sample_poses(dorig['bps_object'])
        output = rh_model(**drec_cnet)
        verts_rh_gen_cnet = output.vertices

        _, h2o, _ = point2point_signed(verts_rh_gen_cnet,
                                       dorig['verts_object'].to(device))

        drec_cnet['trans_rhand_f'] = drec_cnet['transl']
        drec_cnet['global_orient_rhand_rotmat_f'] = aa2rotmat(
            drec_cnet['global_orient']).view(-1, 3, 3)
        drec_cnet['fpose_rhand_rotmat_f'] = aa2rotmat(
            drec_cnet['hand_pose']).view(-1, 15, 3, 3)
        drec_cnet['verts_object'] = dorig['verts_object'].to(device)
        drec_cnet['h2o_dist'] = h2o.abs()

        drec_rnet = refine_net(**drec_cnet)
        output = rh_model(**drec_rnet)
        print("hand shape {} should be idtenty".format(output.betas))
        verts_rh_gen_rnet = output.vertices

        # Reorder joints to match visualization utilities (joint_mapper) (TODO)
        joints_rh_gen_rnet = output.joints  # [:, [0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8, 9, 20]]
        transforms_rh_gen_rnet = output.transforms  # [:, [0, 13, 14, 15, 1, 2, 3, 4, 5, 6, 10, 11, 12, 7, 8, 9]]
        joints_rh_gen_rnet = to_cpu(joints_rh_gen_rnet)
        transforms_rh_gen_rnet = to_cpu(transforms_rh_gen_rnet)

        gen_meshes = []
        for cId in range(0, len(dorig['bps_object'])):
            try:
                obj_mesh = dorig['mesh_object'][cId]
            except:
                obj_mesh = points2sphere(points=to_cpu(
                    dorig['verts_object'][cId]),
                                         radius=0.002,
                                         vc=[145, 191, 219])

            hand_mesh_gen_rnet = Mesh(vertices=to_cpu(verts_rh_gen_rnet[cId]),
                                      faces=rh_model.faces,
                                      vc=[145, 191, 219])
            hand_joint_gen_rnet = joints_rh_gen_rnet[cId]
            hand_transform_gen_rnet = transforms_rh_gen_rnet[cId]

            if 'rotmat' in dorig:
                rotmat = dorig['rotmat'][cId].T
                obj_mesh = obj_mesh.rotate_vertices(rotmat)
                hand_mesh_gen_rnet.rotate_vertices(rotmat)

                hand_joint_gen_rnet = hand_joint_gen_rnet @ rotmat.T
                hand_transform_gen_rnet[:, :, :3, :3] = np.matmul(
                    rotmat[None, ...], hand_transform_gen_rnet[:, :, :3, :3])

            gen_meshes.append([obj_mesh, hand_mesh_gen_rnet])
            if save:
                makepath(save_dir)
                print("saving dir {}".format(save_dir))
                np.save(save_dir + '/joints_%d.npy' % cId, hand_joint_gen_rnet)
                np.save(save_dir + '/trans_%d.npy' % cId,
                        hand_transform_gen_rnet)

        return gen_meshes
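
Note: because each joint is stored as a row vector, applying the rotation `rotmat` to the joints is written as `hand_joint_gen_rnet @ rotmat.T`, which is the row-vector form of applying `rotmat @ p` to every joint `p`. A short NumPy check of that identity:

import numpy as np

R = np.array([[0., -1., 0.], [1., 0., 0.], [0., 0., 1.]])   # 90 deg rotation about z
J = np.array([[1., 0., 0.], [0., 2., 0.]])                  # two joints stored as rows
assert np.allclose(J @ R.T, (R @ J.T).T)                    # row-vector form of R @ p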