Exemplo n.º 1
0
 def loadHierarchy_recur(self, node, lines, joint_pos):
     """
     Recursively attach children to ``node`` using ``hier`` lines.

     A line ``hier <parent> <child>`` whose ``<parent>`` equals ``node.name``
     produces a TreeNode for ``<child>`` placed at ``tuple(joint_pos[<child>])``.
     The new node is linked under ``node`` and then expanded the same way.

     :param node: TreeNode to expand in place
     :param lines: raw text lines of the rig file
     :param joint_pos: mapping from joint name to its position
     """
     for line in lines:
         tokens = line.split()
         # Skip anything that is not a hierarchy entry for this node.
         if tokens[0] != 'hier' or tokens[1] != node.name:
             continue
         child_name = tokens[2]
         child = TreeNode(child_name, tuple(joint_pos[child_name]))
         node.children.append(child)
         child.parent = node
         self.loadHierarchy_recur(child, lines, joint_pos)
Exemplo n.º 2
0
def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
    """
    Predict skeleton structure based on joints.

    The bone network's pairwise connection probabilities are symmetrized into a
    matrix and converted to costs ``-log(p)``. The costs are adjusted by
    ``increase_cost_for_outside_bone`` using ``vox``. A tree is then extracted
    with ``primMST_symmetry``, starting from the root chosen by
    ``getInitId``.

    :param input_data: wrapped data; ``y``, ``num_joint`` and ``pairs`` are read
    :param vox: voxelized mesh
    :param root_pred_net: network to predict root
    :param bone_pred_net: network to predict pairwise connectivity cost
    :param mesh_filename: mesh filename, only used for the debug rendering
    :return: predicted skeleton structure (``Info`` with ``root`` and
        ``joint_pos`` filled in)
    :raises ValueError: if the MST parent array contains no root entry (``-1``)
    """
    root_id = getInitId(input_data, root_pred_net)
    num_joint = input_data.num_joint[0]
    pred_joints = input_data.y[:num_joint].data.cpu().numpy()

    with torch.no_grad():
        connect_prob, _ = bone_pred_net(input_data)
        connect_prob = torch.sigmoid(connect_prob)
    pair_idx = input_data.pairs.long().data.cpu().numpy()
    # Size the matrix from this sample's joint count (input_data), never a global.
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
    # Symmetrize so cost(i, j) == cost(j, i).
    prob_matrix = prob_matrix + prob_matrix.transpose()
    # Probabilities -> costs; the epsilon avoids log(0) for unscored pairs.
    cost_matrix = -np.log(prob_matrix + 1e-10)
    cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

    pred_skel = Info()
    # The parent array is the first element of primMST_symmetry's result; index
    # it explicitly so this works whether two or three values are returned.
    parent = primMST_symmetry(cost_matrix, root_id, pred_joints)[0]
    root_idx = next((i for i, p in enumerate(parent) if p == -1), None)
    if root_idx is None:
        raise ValueError('MST result has no root joint (no parent == -1)')
    pred_skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
    loadSkel_recur(pred_skel.root, root_idx, None, pred_joints, parent)
    pred_skel.joint_pos = pred_skel.get_joint_dict()
    # Debug rendering only; the returned image is not used.
    show_obj_skel(mesh_filename, pred_skel.root)
    return pred_skel
Exemplo n.º 3
0
def add_duplicate_joints(skel, offset_ratio=0.03):
    """
    Split every multi-child joint into one overlapping duplicate per child.

    For each joint with more than one child, a new joint named
    ``<name>_dup_<k>`` is inserted between the joint and its k-th child, so
    each branch gets its own copy of the parent joint. Duplicates are flagged
    with ``overlap = True``; every other visited joint gets ``overlap = False``.

    :param skel: skeleton whose ``root`` is a TreeNode; modified in place
    :param offset_ratio: fraction of the bone (duplicate -> child) by which each
        duplicate is moved toward its child, so the copies do not coincide
    :return: the same ``skel`` object
    """
    if skel.root is None:
        return skel
    # ids of duplicates created here, so their overlap flag is not reset when
    # they are visited on the next level (they stay alive in the tree).
    duplicate_ids = set()
    this_level = [skel.root]
    while this_level:
        next_level = []
        for p_node in this_level:
            if id(p_node) not in duplicate_ids:
                p_node.overlap = False
            if len(p_node.children) > 1:
                new_children = []
                parent_pos = np.array(p_node.pos, dtype=float)
                for dup_id, child in enumerate(p_node.children):
                    p_node_new = TreeNode(
                        p_node.name + '_dup_{:d}'.format(dup_id), p_node.pos)
                    p_node_new.overlap = True
                    p_node_new.parent = p_node
                    p_node_new.children = [child]
                    # For user interaction, move the overlapping joint a little
                    # along the bone toward its child (not diagonally).
                    child_pos = np.array(child.pos, dtype=float)
                    new_pos = parent_pos + offset_ratio * (child_pos - parent_pos)
                    p_node_new.pos = (new_pos[0], new_pos[1], new_pos[2])
                    child.parent = p_node_new
                    new_children.append(p_node_new)
                    duplicate_ids.add(id(p_node_new))
                p_node.children = new_children
            next_level += p_node.children
        this_level = next_level
    return skel
Exemplo n.º 4
0
 def loadSkel_recur(self, node, lines, has_order):
     """
     Recursively attach children to ``node`` from skeleton text lines.

     Every line whose token 5 equals ``node.name`` describes a child of
     ``node``. Token 1 is the child's name and tokens 2-4 are its x/y/z
     position.

     With ``has_order``, children are created in ascending order of the
     integer token 6 (ties broken by the raw line text), and that integer is
     stored on ``child.order``. Without it, children are created in file order.

     :param node: TreeNode to expand in place
     :param lines: raw text lines of the skeleton file
     :param has_order: whether lines carry an ordering integer at token 6
     """
     def make_child(tokens):
         position = (float(tokens[2]), float(tokens[3]), float(tokens[4]))
         return TreeNode(tokens[1], position)

     if not has_order:
         for line in lines:
             tokens = line.split()
             if tokens[5] != node.name:
                 continue
             child = make_child(tokens)
             node.children.append(child)
             child.parent = node
             self.loadSkel_recur(child, lines, has_order)
         return

     # Gather matching lines first, then visit them smallest order first.
     pending = []
     for line in lines:
         tokens = line.split()
         if tokens[5] == node.name:
             pending.append((int(tokens[6]), line))
     pending.sort()
     for rank, line in pending:
         child = make_child(line.split())
         child.order = rank
         node.children.append(child)
         child.parent = node
         self.loadSkel_recur(child, lines, has_order)
Exemplo n.º 5
0
def loadSkel_recur(p_node, parent_id, joint_name, joint_pos, parent):
    """
    Convert the Prim MST result into the skel/info tree format, recursively.

    Every index ``i`` with ``parent[i] == parent_id`` becomes a child TreeNode
    of ``p_node``. The child is named ``joint_name[i]``, or ``'joint_<i>'``
    when ``joint_name`` is None, and positioned at ``tuple(joint_pos[i])``.
    Each child is then expanded the same way with ``i`` as the new parent index.

    :param p_node: node to expand in place (the root node on the first call)
    :param parent_id: index of the joint that ``p_node`` represents
    :param joint_name: list of joint names, or None for generated names
    :param joint_pos: joint positions, indexable by joint index
    :param parent: parent index of every joint, as returned by Prim's algorithm
    :return: None; ``p_node`` is extended with the whole subtree below it
    """
    child_ids = [idx for idx, par in enumerate(parent) if par == parent_id]
    for idx in child_ids:
        name = joint_name[idx] if joint_name is not None else 'joint_{}'.format(idx)
        child = TreeNode(name, tuple(joint_pos[idx]))
        p_node.children.append(child)
        child.parent = p_node
        loadSkel_recur(child, idx, joint_name, joint_pos, parent)
Exemplo n.º 6
0
def run_mst_generate(args):
    """
    Generate skeletons in batch.

    For every model id listed in ``<dataset_folder>/test_final.txt``:
      * predict joints;
      * score joint pairs with the connectivity network;
      * extract a minimum spanning tree rooted at the predicted root joint;
      * write ``<id>_skel.jpg`` (rendering) and ``<id>_skel.txt`` (skeleton)
        to ``args.res_folder``.

    :param args: parsed options. ``dataset_folder`` and ``res_folder`` are
        read here, as are the checkpoint paths ``rootnet`` and ``bonenet``;
        ``args`` is also forwarded to ``predict_joints``.
    :raises ValueError: if a model's MST result contains no root joint
    """
    # Built-in ``int`` instead of ``np.int`` (removed in NumPy 1.24);
    # atleast_1d keeps a one-entry list iterable (loadtxt returns 0-d then).
    test_list = np.atleast_1d(
        np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'),
                   dtype=int))

    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    # map_location lets checkpoints saved on another device (e.g. GPU) load here.
    root_checkpoint = torch.load(args.rootnet, map_location=device)
    root_select_model.load_state_dict(root_checkpoint['state_dict'])
    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet, map_location=device)
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder,
                                     'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)
        with torch.no_grad():
            logits, _ = connectivity_model.forward(data)
            connect_prob = torch.sigmoid(logits)

        num_joint = data.num_joint[0]
        pair_idx = data.pairs.long().data.cpu().numpy()
        prob_matrix = np.zeros((num_joint, num_joint))
        prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = \
            connect_prob.data.cpu().numpy().squeeze()
        # Symmetrize, then turn probabilities into costs (epsilon avoids log(0)).
        prob_matrix = prob_matrix + prob_matrix.transpose()
        cost_matrix = -np.log(prob_matrix + 1e-10)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints,
                                                     vox)

        skel = Skel()
        # The parent array is the first element of primMST_symmetry's result.
        parent = primMST_symmetry(cost_matrix, root_id, pred_joints)[0]
        root_idx = next((i for i, p in enumerate(parent) if p == -1), None)
        if root_idx is None:
            raise ValueError(
                'MST result for model {} has no root joint'.format(model_id))
        skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
        loadSkel_recur(skel.root, root_idx, None, pred_joints, parent)
        img = show_obj_skel(mesh_filename, skel.root)
        # The rendering is flipped along the channel axis before writing
        # (presumably RGB -> BGR for OpenCV; confirm against show_obj_skel).
        cv2.imwrite(
            os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)),
            img[:, :, ::-1])
        skel.save(
            os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))
Exemplo n.º 7
0
 def load(self, filename):
     """
     Load a skeleton from a whitespace-separated text file.

     For each non-blank line:
       * token 1 is the joint name;
       * tokens 2-4 are its x/y/z position;
       * token 5 names its parent joint (``None`` for the root);
       * an optional token 6 holds an integer order.

     The first line whose parent is ``None`` becomes ``self.root``. Whether
     that line carries the order token decides if children are loaded in
     order.

     :param filename: path to the skeleton text file
     :raises ValueError: if no line has parent ``None`` (no root joint)
     """
     with open(filename, 'r') as fin:
         # Blank lines carry no joint and would break the token indexing here
         # and in loadSkel_recur.
         lines = [li for li in fin.readlines() if li.strip()]
     for li in lines:
         words = li.split()
         if words[5] == "None":
             self.root = TreeNode(
                 words[1],
                 (float(words[2]), float(words[3]), float(words[4])))
             if len(words) == 7:
                 has_order = True
                 self.root.order = int(words[6])
             else:
                 has_order = False
             break
     else:
         raise ValueError(
             'no root joint (parent "None") found in {}'.format(filename))
     self.loadSkel_recur(self.root, lines, has_order)
Exemplo n.º 8
0
 def load(self, filename):
     """
     Load joints, root, skin entries and hierarchy from a rig text file.

     Line types handled here (by first token):
       * ``joints <name> <x> <y> <z>`` -- stored in ``self.joint_pos[name]``
       * ``root <name>`` -- names the root joint
       * ``skin ...`` -- remaining tokens appended to ``self.joint_skin``

     Other line types are skipped here. Blank lines are dropped. The remaining
     lines are then passed to ``loadHierarchy_recur`` to build the tree under
     ``self.root``.

     :param filename: path to the rig text file
     :raises ValueError: if the file has no ``root`` line
     :raises KeyError: if the root joint has no ``joints`` line
     """
     with open(filename, 'r') as f_txt:
         # Blank lines have no first token to dispatch on.
         lines = [line for line in f_txt.readlines() if line.split()]
     root_name = None
     for line in lines:
         word = line.split()
         if word[0] == 'joints':
             self.joint_pos[word[1]] = [
                 float(word[2]),
                 float(word[3]),
                 float(word[4])
             ]
         elif word[0] == 'root':
             # Resolved after parsing, so ``root`` may precede its ``joints`` line.
             root_name = word[1]
         elif word[0] == 'skin':
             skin_item = word[1:]
             self.joint_skin.append(skin_item)
     if root_name is None:
         raise ValueError('no "root" line found in {}'.format(filename))
     root_pos = self.joint_pos[root_name]
     self.root = TreeNode(root_name, (root_pos[0], root_pos[1], root_pos[2]))
     self.loadHierarchy_recur(self.root, lines, self.joint_pos)
Exemplo n.º 9
0
for i, character in enumerate(characters_list):
    # if i != args.character_idx:
    #     continue

    for motion in motions_list[args.start:args.last]:
        joint_pos_file = os.path.join(joint_log, 'test', character,
                                      '%s_joint.npy' % (motion))
        joint_pos = np.load(joint_pos_file, allow_pickle=True).item()
        joint_result = joint_pos[prediction_method]

        # save skeleton
        pred_skel = Info()
        nodes = []
        for joint_index, joint_pos in enumerate(joint_result):
            nodes.append(TreeNode(name=joint_name[joint_index], pos=joint_pos))

        pred_skel.root = nodes[0]
        for parent, children in enumerate(tree):
            for child in children:
                nodes[parent].children.append(nodes[child])
                nodes[child].parent = nodes[parent]

        # calculate volumetric geodesic distance
        bones, _, _ = get_bones(pred_skel)
        mesh_filename = os.path.join(data_dir, 'objs',
                                     character + '/' + motion + '.obj')
        # mesh_filename = os.path.join(data_dir, 'test_objs', character + '_' + motion + '.obj')
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        mesh.compute_vertex_normals()
        mesh_v = np.asarray(mesh.vertices)