Example #1
0
def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
    """
    Predict skeleton structure based on joints
    :param input_data: wrapped data; this function reads its ``y``, ``num_joint``
                       and ``pairs`` attributes
    :param vox: voxelized mesh, forwarded to increase_cost_for_outside_bone
    :param root_pred_net: network used by getInitId to pick the root joint
    :param bone_pred_net: network whose first output is passed through a sigmoid
                          and used as the per-pair connection probability
    :param mesh_filename: mesh filename, forwarded to show_obj_skel
    :return: predicted skeleton structure (an Info object with ``root`` and
             ``joint_pos`` filled in)
    """
    root_id = getInitId(input_data, root_pred_net)
    # Joint count comes from the function's own argument. The original read it
    # from an undefined name `data`, which only worked if a module-level
    # `data` happened to exist.
    num_joint = input_data.num_joint[0]
    pred_joints = input_data.y[:num_joint].data.cpu().numpy()

    with torch.no_grad():
        connect_logits, _ = bone_pred_net(input_data)
        connect_prob = torch.sigmoid(connect_logits)

    # Scatter the per-pair probabilities into a dense joint-by-joint matrix.
    pair_idx = input_data.pairs.long().data.cpu().numpy()
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
    # Adding the transpose makes the matrix symmetric.
    # NOTE(review): assumes each unordered pair is listed once in `pairs`,
    # otherwise the two directions are summed - confirm against the caller.
    prob_matrix = prob_matrix + prob_matrix.transpose()
    # Entries never written above stay 0; the epsilon keeps their log finite.
    cost_matrix = -np.log(prob_matrix + 1e-10)
    cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

    pred_skel = Info()
    parent, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)
    # The root is the first joint whose parent index is -1.
    for i, parent_id in enumerate(parent):
        if parent_id == -1:
            pred_skel.root = TreeNode('root', tuple(pred_joints[i]))
            break
    loadSkel_recur(pred_skel.root, i, None, pred_joints, parent)
    pred_skel.joint_pos = pred_skel.get_joint_dict()
    #show_mesh_vox(mesh_filename, vox, pred_skel.root)
    # Return value is not used here; the call is kept in case it renders or
    # displays the result as a side effect - TODO confirm.
    show_obj_skel(mesh_filename, pred_skel.root)
    return pred_skel
Example #2
0
def run_mst_generate(args):
    """
    generate skeleton in batch
    :param args: options object; this function reads ``dataset_folder``
                 (must contain ``test_final.txt`` and ``obj_remesh/<id>.obj``),
                 ``rootnet`` and ``bonenet`` (checkpoint paths) and
                 ``res_folder`` (where ``<id>_skel.jpg`` / ``<id>_skel.txt``
                 are written). It is also forwarded to predict_joints.
    """
    # `np.int` was removed in NumPy 1.24; the builtin `int` is the replacement.
    # ndmin=1 keeps a single-line list file iterable (loadtxt would otherwise
    # return a 0-d array).
    test_list = np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'),
                           dtype=int, ndmin=1)

    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    # map_location lets checkpoints saved on another device load onto `device`.
    root_checkpoint = torch.load(args.rootnet, map_location=device)
    root_select_model.load_state_dict(root_checkpoint['state_dict'])

    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet, map_location=device)
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder,
                                     'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)

        with torch.no_grad():
            connect_logits, _ = connectivity_model(data)
            connect_prob = torch.sigmoid(connect_logits)

        # Scatter the per-pair probabilities into a dense joint-by-joint matrix.
        num_joint = data.num_joint[0]
        pair_idx = data.pairs.long().data.cpu().numpy()
        prob_matrix = np.zeros((num_joint, num_joint))
        prob_matrix[pair_idx[:, 0],
                    pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
        # Adding the transpose makes the matrix symmetric.
        # NOTE(review): assumes each unordered pair is listed once in `pairs`,
        # otherwise the two directions are summed - confirm.
        prob_matrix = prob_matrix + prob_matrix.transpose()
        # Entries never written above stay 0; the epsilon keeps their log finite.
        cost_matrix = -np.log(prob_matrix + 1e-10)
        #cost_matrix = flip_cost_matrix(pred_joints, cost_matrix)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints,
                                                     vox)

        skel = Skel()
        parent, _, root_id = primMST_symmetry(cost_matrix, root_id,
                                              pred_joints)
        # The root is the first joint whose parent index is -1.
        for i, parent_id in enumerate(parent):
            if parent_id == -1:
                skel.root = TreeNode('root', tuple(pred_joints[i]))
                break
        loadSkel_recur(skel.root, i, None, pred_joints, parent)

        img = show_obj_skel(mesh_filename, skel.root)
        # The channel axis is reversed before writing.
        # NOTE(review): presumably RGB -> BGR for OpenCV - confirm what
        # show_obj_skel returns.
        cv2.imwrite(
            os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id)),
            img[:, :, ::-1])
        skel.save(
            os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))