def deform_mesh_by_closest_vertices(mesh,
                                    src_mesh,
                                    tar_mesh,
                                    device=torch.device("cpu")):
    """
    Deform ``mesh`` toward the shape of ``tar_mesh`` using nearest-vertex lookup.

    For every vertex of ``mesh`` the index of the closest ``src_mesh`` vertex
    is looked up, and the vertex is moved by the offset
    ``tar_mesh.v[idx] - src_mesh.v[idx]``. Because the same index is used for
    both meshes, ``src_mesh`` and ``tar_mesh`` are assumed to share their
    vertex ordering.

    Args:
        mesh: Mesh to deform; must expose ``verts_packed()`` and
            ``faces_packed()`` (PyTorch3D-style).
        src_mesh: Reference mesh in the source shape (same accessors).
        tar_mesh: Reference mesh in the target shape (same accessors).
        device: Device the returned mesh is moved to.

    Returns:
        A ``Meshes`` object holding one mesh with the displaced vertices and
        the faces of the input ``mesh``.

    Raises:
        NotImplementedError: If the packed vertex tensor of ``mesh`` is not
            2-D. Batched input is not supported yet (TODO).
    """
    packed_verts = mesh.verts_packed()
    # NOTE(review): PyTorch3D documents verts_packed() as a 2-D tensor, so this
    # guard presumably never fires for real Meshes objects - confirm. It is
    # kept so the unsupported case fails loudly instead of falling through to
    # an UnboundLocalError at the return statement.
    if packed_verts.dim() != 2:
        raise NotImplementedError(
            "deform_mesh_by_closest_vertices only supports a single mesh")

    # pytorch3d -> psbody.mesh conversion (keep the face tensor for the result)
    mesh_face_pytorch3d = mesh.faces_packed()
    mesh = Mesh(packed_verts.detach().cpu().numpy(),
                mesh_face_pytorch3d.detach().cpu().numpy())
    src_mesh = Mesh(src_mesh.verts_packed().detach().cpu().numpy(),
                    src_mesh.faces_packed().detach().cpu().numpy())
    tar_mesh = Mesh(tar_mesh.verts_packed().detach().cpu().numpy(),
                    tar_mesh.faces_packed().detach().cpu().numpy())

    # verts_idx: for each vertex of `mesh`, index of the closest src_mesh vertex
    verts_idx, _ = src_mesh.closest_vertices(mesh.v)
    verts_idx = np.array(verts_idx)

    # Apply the src -> tar displacement of the matched vertex to each vertex.
    new_mesh_verts = mesh.v - src_mesh.v[verts_idx] + tar_mesh.v[verts_idx]

    # psbody.mesh -> pytorch3d conversion (batch dimension of size 1)
    new_mesh = Meshes(
        torch.from_numpy(
            new_mesh_verts).requires_grad_(False).float().unsqueeze(0),
        mesh_face_pytorch3d.unsqueeze(0)).to(device)

    return new_mesh
Exemple #2
0
    def __getitem__(self, idx):
        """
        Build the sample dictionary for the entry at position ``idx``.

        Loads the SMPL mesh and the scan mesh from the sample directory,
        samples ``NUM_POINTS`` points on the scan surface, optionally applies
        a random rotation, and derives part labels and correspondences from
        the closest SMPL vertices. In ``'train'`` mode only the training
        fields are returned; otherwise the SMPL vertices and per-point
        colours are included as well.
        """
        sample_dir = self.data[idx]
        sample_name = split(sample_dir)[1]

        smpl_mesh = Mesh(filename=join(sample_dir, sample_name + '_smpl.obj'))
        # With `naked` set, the SMPL mesh itself (loaded a second time as an
        # independent object) plays the role of the scan.
        scan_suffix = '_smpl.obj' if self.naked else '.obj'
        scan_mesh = Mesh(filename=join(sample_dir, sample_name + scan_suffix))

        surface = trimesh.Trimesh(vertices=scan_mesh.v, faces=scan_mesh.f)
        sampled = surface.sample(NUM_POINTS)

        if self.augment:
            # Rotate the sampled points and the SMPL vertices consistently.
            rotation = self.get_rnd_rotations()
            sampled = rotation.apply(sampled)
            smpl_mesh.v = rotation.apply(smpl_mesh.v)

        nearest_idx, _ = smpl_mesh.closest_vertices(sampled)
        labels = self.smpl_parts[np.array(nearest_idx)]
        corr = self.map_mesh_points_to_reference(
            sampled, smpl_mesh, self.ref_smpl.r)

        training = self.mode == 'train'
        # Vertex colours are only computed outside of training.
        colours = (None if training else
                   self.map_vitruvian_vertex_color(sampled, smpl_mesh))

        # Assemble the result step by step so both modes keep their key order.
        sample = {'scan': sampled.astype('float32')}
        if not training:
            sample['smpl'] = smpl_mesh.v.astype('float32')
        sample['correspondences'] = corr.astype('float32')
        sample['part_labels'] = labels.astype('float32')
        if not training:
            sample['scan_vc'] = colours
        sample['name'] = sample_dir
        return sample
Exemple #3
0
    def __getitem__(self, idx):
        """
        Build the sample dictionary for the entry at position ``idx``.

        Loads the SMPL mesh and the scan mesh from the sample directory,
        samples ``NUM_POINTS`` points on the scan surface, optionally applies
        a random rotation, and derives part labels and correspondences from
        the closest SMPL vertices. If ``self.cache_suffix`` is set and cached
        ``*.pkl`` files exist, the last one in sorted order supplies
        ``pose``/``betas``/``trans``; otherwise zero arrays are used.

        Returns:
            dict: In ``'train'`` mode the keys ``scan``, ``correspondences``,
            ``part_labels``, ``pose``, ``betas``, ``trans`` and ``name``;
            in any other mode additionally ``smpl`` and ``scan_vc``.
        """
        path = self.data[idx]
        name = split(path)[1]

        input_smpl = Mesh(filename=join(path, name + '_smpl.obj'))
        if self.naked:
            # Use the SMPL mesh itself (loaded as a separate object) as scan.
            input_scan = Mesh(filename=join(path, name + '_smpl.obj'))
        else:
            input_scan = Mesh(filename=join(path, name + '.obj'))
        temp = trimesh.Trimesh(vertices=input_scan.v, faces=input_scan.f)
        points = temp.sample(NUM_POINTS)

        if self.augment:
            # Rotate the sampled points and the SMPL vertices consistently.
            rot = self.get_rnd_rotations()
            points = rot.apply(points)
            input_smpl.v = rot.apply(input_smpl.v)

        # Index of the closest SMPL vertex for every sampled point.
        ind, _ = input_smpl.closest_vertices(points)
        part_labels = self.smpl_parts[np.array(ind)]
        correspondences = self.map_mesh_points_to_reference(
            points, input_smpl, self.ref_smpl.r)

        # Load cached SMPL params
        cache_list = []
        if self.cache_suffix is not None:
            cache_list = sorted(glob(join(path, self.cache_suffix, '*.pkl')))
        if cache_list:
            # Close the file deterministically instead of leaking one handle
            # per loaded sample.
            # NOTE: unpickling can execute arbitrary code; cache files must
            # come from a trusted location.
            with open(cache_list[-1], 'rb') as cache_file:
                smpl_dict = pkl.load(cache_file, encoding='latin-1')
            pose = smpl_dict['pose']
            betas = smpl_dict['betas']
            trans = smpl_dict['trans']
        else:
            # No cache available: fall back to all-zero parameters.
            pose = np.zeros((72, ))
            betas = np.zeros((10, ))
            trans = np.zeros((3, ))

        if self.mode == 'train':
            return {
                'scan': points.astype('float32'),
                'correspondences': correspondences.astype('float32'),
                'part_labels': part_labels.astype('float32'),
                'pose': pose.astype('float32'),
                'betas': betas.astype('float32'),
                'trans': trans.astype('float32'),
                'name': path
            }

        # Outside of training, also return SMPL vertices and per-point colours.
        vc = self.map_vitruvian_vertex_color(points, input_smpl)
        return {
            'scan': points.astype('float32'),
            'smpl': input_smpl.v.astype('float32'),
            'correspondences': correspondences.astype('float32'),
            'part_labels': part_labels.astype('float32'),
            'pose': pose.astype('float32'),
            'betas': betas.astype('float32'),
            'trans': trans.astype('float32'),
            'scan_vc': vc,
            'name': path
        }