def tranfer_to_ori_mesh(filename_ori, filename_remesh, pred_rig):
    """
    Carry a rig predicted on the remeshed model over to the original model.

    The skeleton (root and joint positions) is reused as-is. Skinning is
    transferred by nearest neighbour: every vertex of the original mesh takes
    the skinning entry of the closest vertex of the remeshed mesh.
    :param filename_ori: original mesh filename
    :param filename_remesh: remeshed mesh filename
    :param pred_rig: rig predicted on the remeshed mesh
    :return: predicted rig for the original mesh
    """
    remeshed = o3d.io.read_triangle_mesh(filename_remesh)
    original = o3d.io.read_triangle_mesh(filename_ori)
    rig_for_original = Info()

    remeshed_pts = np.asarray(remeshed.vertices)
    original_pts = np.asarray(original.vertices)

    # Dense pairwise Euclidean distances, laid out as
    # (number of remeshed vertices, number of original vertices).
    offsets = original_pts[np.newaxis, ...] - remeshed_pts[:, np.newaxis, :]
    pairwise_dist = np.sqrt((offsets ** 2).sum(axis=2))

    # Reducing over axis 0 yields one entry per original vertex: the index of
    # its closest remeshed vertex.
    closest_remeshed = pairwise_dist.argmin(axis=0)

    # The skeleton itself does not depend on the mesh tessellation.
    rig_for_original.root = pred_rig.root
    rig_for_original.joint_pos = pred_rig.joint_pos

    # Each skin entry starts with the vertex id; replace it with the original
    # vertex id and copy the remainder from the matched remeshed vertex.
    rig_for_original.joint_skin = [
        [ori_id, *pred_rig.joint_skin[remeshed_id][1:]]
        for ori_id, remeshed_id in enumerate(closest_remeshed)
    ]
    return rig_for_original
def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
    """
    Predict skeleton structure based on joints

    The pairwise connectivity scores from ``bone_pred_net`` are turned into a
    symmetric cost matrix, a spanning tree is extracted from it with
    ``primMST_symmetry`` starting at the predicted root, and the resulting
    parent array is converted into a ``TreeNode`` hierarchy.
    :param input_data: wrapped data
    :param vox: voxelized mesh
    :param root_pred_net: network to predict root
    :param bone_pred_net: network to predict pairwise connectivity cost
    :param mesh_filename: meshfilename for debugging
    :return: predicted skeleton structure
    """
    root_id = getInitId(input_data, root_pred_net)
    # Joint count comes from the sample being processed; the matrix below must
    # be sized from the same object, not from any module-level variable.
    num_joint = input_data.num_joint[0]
    pred_joints = input_data.y[:num_joint].data.cpu().numpy()

    with torch.no_grad():
        connect_prob, _ = bone_pred_net(input_data)
        connect_prob = torch.sigmoid(connect_prob)
    pair_idx = input_data.pairs.long().data.cpu().numpy()

    # Scatter the per-pair probabilities into a square matrix and mirror it so
    # that the connection (a, b) is also available as (b, a).
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
    prob_matrix = prob_matrix + prob_matrix.transpose()

    # High probability -> low cost; the epsilon keeps log() finite for pairs
    # that received no probability.
    cost_matrix = -np.log(prob_matrix + 1e-10)
    # NOTE(review): judging by its name, this raises the cost of bones that
    # leave the voxelized shape -- confirm against the helper's definition.
    cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

    pred_skel = Info()
    parent, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)

    # The joint whose parent is -1 is the root of the spanning tree. The loop
    # variable is reused afterwards as the root's joint index.
    for joint_idx, parent_idx in enumerate(parent):
        if parent_idx == -1:
            pred_skel.root = TreeNode('root', tuple(pred_joints[joint_idx]))
            break
    loadSkel_recur(pred_skel.root, joint_idx, None, pred_joints, parent)
    pred_skel.joint_pos = pred_skel.get_joint_dict()

    # Called for its effect only; the returned value is not needed here.
    show_obj_skel(mesh_filename, pred_skel.root)
    return pred_skel