Ejemplo n.º 1
0
def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
    """
    Predict skeleton structure based on joints.

    A root joint is chosen with ``root_pred_net``, pairwise connection
    probabilities come from ``bone_pred_net``, and a minimum spanning tree over
    the resulting cost matrix (computed by ``primMST_symmetry``) defines the
    bone hierarchy.

    :param input_data: wrapped data; ``input_data.y`` holds joint positions,
        ``input_data.num_joint[0]`` the joint count and ``input_data.pairs``
        the candidate joint pairs scored by ``bone_pred_net``
    :param vox: voxelized mesh, used to penalize bones that leave the shape
    :param root_pred_net: network to predict root
    :param bone_pred_net: network to predict pairwise connectivity cost
    :param mesh_filename: mesh filename, used for the debug visualization
    :return: predicted skeleton structure (``Info`` with ``root`` and ``joint_pos`` set)
    :raises RuntimeError: if the spanning tree has no root (no ``parent == -1``)
    """
    root_id = getInitId(input_data, root_pred_net)
    num_joint = int(input_data.num_joint[0])
    pred_joints = input_data.y[:num_joint].data.cpu().numpy()

    with torch.no_grad():
        connect_prob, _ = bone_pred_net(input_data)
        connect_prob = torch.sigmoid(connect_prob)
    pair_idx = input_data.pairs.long().data.cpu().numpy()

    # Scatter per-pair probabilities into a dense (num_joint, num_joint) matrix
    # and symmetrize it so that cost(i, j) == cost(j, i).
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
    prob_matrix = prob_matrix + prob_matrix.transpose()
    # Negative log-probability: likely connections are cheap; epsilon avoids log(0).
    cost_matrix = -np.log(prob_matrix + 1e-10)
    cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

    pred_skel = Info()
    parent, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)
    # The tree root is the first joint without a parent.
    root_idx = next((i for i, p in enumerate(parent) if p == -1), None)
    if root_idx is None:
        raise RuntimeError('spanning tree for {} has no root joint'.format(mesh_filename))
    pred_skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
    loadSkel_recur(pred_skel.root, root_idx, None, pred_joints, parent)
    pred_skel.joint_pos = pred_skel.get_joint_dict()
    # Debug visualization of the predicted skeleton; its return value is unused.
    show_obj_skel(mesh_filename, pred_skel.root)
    return pred_skel
Ejemplo n.º 2
0
def _split_file(split_name, model_id, suffix):
    """Return the path ``<dataset_folder>/<split_name>/<model_id><suffix>``."""
    return os.path.join(dataset_folder,
                        '{:s}/{:d}{:s}'.format(split_name, model_id, suffix))


def _select_models(process_id):
    """Return ``(model_ids, split_name)`` assigned to worker ``process_id``.

    Workers 0-5 each take a 365-model slice of ``train_final.txt``; worker 6
    takes ``val_final.txt`` and worker 7 takes ``test_final.txt``.

    :raises ValueError: if ``process_id`` is not in ``[0, 7]``
    """
    num_train_workers = 6
    train_chunk_size = 365
    if 0 <= process_id < num_train_workers:
        list_name, split_name = 'train_final.txt', 'train'
    elif process_id == num_train_workers:
        list_name, split_name = 'val_final.txt', 'val'
    elif process_id == num_train_workers + 1:
        list_name, split_name = 'test_final.txt', 'test'
    else:
        raise ValueError('process_id must be in [0, {:d}], got {}'.format(
            num_train_workers + 1, process_id))
    # ndmin=1 keeps a single-id list iterable (otherwise loadtxt returns a 0-d array).
    model_list = np.loadtxt(os.path.join(dataset_folder, list_name),
                            dtype=int, ndmin=1)
    if split_name == 'train':
        start = train_chunk_size * process_id
        model_list = model_list[start:start + train_chunk_size]
    return model_list, split_name


def _build_skin_samples(rig_info, geo_dist, bone_names, bone_isleaf,
                        num_vertices, num_nearest_bone=5):
    """Build per-vertex skinning inputs and ground-truth weights.

    For each vertex, the ``num_nearest_bone`` bones with the smallest
    ``geo_dist[vertex, bone]`` are listed as ``(bone_id, 1 / D_g, is_leaf)``.
    Missing slots are padded with ``(-1, 0, 0)``. The label of a slot is the
    skin weight of that bone's start joint. Each joint's weight is assigned to
    at most one slot: the nearest bone starting at it.

    :return: ``(input_samples, ground_truth_labels)``
    """
    input_samples = []  # mesh_vertex_id, (bone_id, 1 / D_g, is_leaf) * N
    ground_truth_labels = []  # w_1, w_2, ..., w_N
    for vert_id in range(num_vertices):
        skin = rig_info.joint_skin[vert_id]
        # skin[1:] alternates joint name / weight (index 0 is skipped,
        # presumably the vertex id — confirm against Info).
        skin_w = {skin[i]: float(skin[i + 1]) for i in range(1, len(skin), 2)}
        bone_order = np.argsort(geo_dist[vert_id, :])
        sample = [vert_id]
        label = []
        for rank in range(num_nearest_bone):
            if rank >= len(bone_order):
                # Fewer bones than slots: pad with a dummy bone and zero weight.
                sample += [-1, 0, 0]
                label.append(0.0)
                continue
            bone_id = bone_order[rank]
            sample += [bone_id,
                       1.0 / (geo_dist[vert_id, bone_id] + 1e-10),
                       bone_isleaf[bone_id]]
            label.append(skin_w.pop(bone_names[bone_id][0], 0.0))
        input_samples.append(sample)
        ground_truth_labels.append(label)
    return input_samples, ground_truth_labels


def _write_skin_file(skin_filename, bone_pos, bone_names, input_samples,
                     ground_truth_labels):
    """Write the ``bones`` / ``bind`` / ``influence`` records of one model."""
    with open(skin_filename, 'w') as fout:
        for i in range(len(bone_pos)):
            fout.write('bones {:s} {:s} {:.6f} {:.6f} {:.6f} '
                       '{:.6f} {:.6f} {:.6f}\n'.format(
                           bone_names[i][0], bone_names[i][1],
                           bone_pos[i, 0], bone_pos[i, 1], bone_pos[i, 2],
                           bone_pos[i, 3], bone_pos[i, 4], bone_pos[i, 5]))
        for sample in input_samples:
            fout.write('bind {:d} '.format(sample[0])
                       + ''.join('{:d} {:.6f} {:d} '.format(
                           sample[j], sample[j + 1], sample[j + 2])
                           for j in range(1, len(sample), 3))
                       + '\n')
        for label in ground_truth_labels:
            fout.write('influence '
                       + ''.join('{:.3f} '.format(w) for w in label)
                       + '\n')


def genDataset(process_id):
    """Generate the preprocessed per-model files for one worker's share of models.

    Worker ids 0-5 each handle a 365-model slice of ``train_final.txt``;
    id 6 handles ``val_final.txt`` and id 7 handles ``test_final.txt``.
    For every model id, files are written under
    ``<dataset_folder>/<split_name>/``:
    ``_v.txt`` (vertex positions and normals), ``_tpl_e.txt`` (topology
    edges), ``_geo_e.txt`` (geodesic edges), ``_j.txt`` (joint positions),
    ``_adj.txt`` (joint adjacency), ``_skin.txt`` (skinning samples), and
    copies of the pre-trained attention (``_attn.txt``) and voxel file
    (``.binvox``).

    :param process_id: worker index in ``[0, 7]``
    :raises ValueError: if ``process_id`` is outside ``[0, 7]``
    """
    print("process ID {:d}".format(process_id))
    model_list, split_name = _select_models(process_id)
    mkdir_p(os.path.join(dataset_folder, split_name))
    for model_id in model_list:
        remeshed_obj = o3d.io.read_triangle_mesh(os.path.join(
            dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id)))
        remesh_obj_v = np.asarray(remeshed_obj.vertices)
        remesh_obj_vn = np.asarray(remeshed_obj.vertex_normals)
        remesh_obj_f = np.asarray(remeshed_obj.triangles)
        rig_info = Info(os.path.join(
            dataset_folder, 'rig_info_remesh/{:d}.txt'.format(model_id)))

        # vertices: each row is the vertex position followed by its normal
        np.savetxt(_split_file(split_name, model_id, '_v.txt'),
                   np.concatenate((remesh_obj_v, remesh_obj_vn), axis=1),
                   fmt='%.6f')

        # topology edges
        np.savetxt(_split_file(split_name, model_id, '_tpl_e.txt'),
                   get_tpl_edges(remesh_obj_v, remesh_obj_f), fmt='%d')

        # geodesic edges
        surface_geodesic = calc_surface_geodesic(remeshed_obj)
        np.savetxt(_split_file(split_name, model_id, '_geo_e.txt'),
                   get_geo_edges(surface_geodesic, remesh_obj_v), fmt='%d')

        # joints and their adjacency matrix
        joint_pos_list = [np.array(pos)
                          for pos in rig_info.get_joint_dict().values()]
        np.savetxt(_split_file(split_name, model_id, '_adj.txt'),
                   rig_info.adjacent_matrix(), fmt='%d')
        np.savetxt(_split_file(split_name, model_id, '_j.txt'),
                   np.array(joint_pos_list), fmt='%.6f')

        # pre-trained attention
        shutil.copyfile(
            os.path.join(dataset_folder,
                         'pretrain_attention/{:d}.txt'.format(model_id)),
            _split_file(split_name, model_id, '_attn.txt'))

        # voxel
        shutil.copyfile(
            os.path.join(dataset_folder, 'vox/{:d}.binvox'.format(model_id)),
            _split_file(split_name, model_id, '.binvox'))

        # skinning information; geo_dist is indexed as [vertex, bone]
        geo_dist = np.load(os.path.join(
            dataset_folder,
            'volumetric_geodesic/{:d}_volumetric_geo.npy'.format(model_id)))
        bone_pos, bone_names, bone_isleaf = get_bones(rig_info)
        input_samples, ground_truth_labels = _build_skin_samples(
            rig_info, geo_dist, bone_names, bone_isleaf, len(remesh_obj_v))
        _write_skin_file(_split_file(split_name, model_id, '_skin.txt'),
                         bone_pos, bone_names, input_samples,
                         ground_truth_labels)