def predict_skeleton(input_data, vox, root_pred_net, bone_pred_net, mesh_filename):
    """
    Predict skeleton structure based on joints.

    Pairwise connection probabilities from ``bone_pred_net`` are turned into
    edge costs ``-log(p)``. A spanning tree rooted near the joint chosen by
    ``root_pred_net`` is then extracted with ``primMST_symmetry``.

    :param input_data: wrapped data; the first ``input_data.num_joint[0]`` rows of
        ``input_data.y`` are used as joint positions, and ``input_data.pairs``
        holds the (i, j) joint index pairs scored by ``bone_pred_net``
    :param vox: voxelized mesh
    :param root_pred_net: network to predict root
    :param bone_pred_net: network to predict pairwise connectivity cost
    :param mesh_filename: mesh filename, only used for the debug rendering
    :return: predicted skeleton structure (an ``Info`` with ``root`` and
        ``joint_pos`` filled in)
    :raises ValueError: if the spanning tree has no joint whose parent is -1
    """
    root_id = getInitId(input_data, root_pred_net)
    num_joint = int(input_data.num_joint[0])
    pred_joints = input_data.y[:num_joint].data.cpu().numpy()

    # Score every candidate joint pair; sigmoid maps the network output to probabilities.
    with torch.no_grad():
        connect_prob, _ = bone_pred_net(input_data)
        connect_prob = torch.sigmoid(connect_prob)
    pair_idx = input_data.pairs.long().data.cpu().numpy()

    # Scatter the pair probabilities into a dense matrix, then symmetrize it.
    prob_matrix = np.zeros((num_joint, num_joint))
    prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
    prob_matrix = prob_matrix + prob_matrix.transpose()

    # Sum of -log(p) over tree edges == -log(product of p): a minimum-cost tree
    # maximizes the joint connection probability. The epsilon avoids log(0)
    # for pairs that were never scored.
    cost_matrix = -np.log(prob_matrix + 1e-10)
    cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

    pred_skel = Info()
    # primMST_symmetry returns (parent, key, root_id); only the parent array is used here.
    parent, _, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)

    # The root of the tree is the first joint without a parent.
    root_idx = next((idx for idx, par in enumerate(parent) if par == -1), None)
    if root_idx is None:
        raise ValueError('spanning tree has no root joint (no parent == -1)')
    pred_skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
    loadSkel_recur(pred_skel.root, root_idx, None, pred_joints, parent)
    pred_skel.joint_pos = pred_skel.get_joint_dict()

    # Debug rendering of the predicted skeleton; the returned image is not used.
    show_obj_skel(mesh_filename, pred_skel.root)
    return pred_skel
def run_mst_generate(args):
    """
    Generate skeletons in batch.

    For every model id listed in ``<dataset_folder>/test_final.txt``, this:

    1. predicts joints;
    2. scores joint pairs with the connectivity network;
    3. extracts a spanning tree with ``primMST_symmetry``;
    4. writes ``<id>_skel.jpg`` (a rendering) and ``<id>_skel.txt`` into
       ``args.res_folder``.

    :param args: namespace providing ``dataset_folder``, ``res_folder``,
        ``rootnet`` and ``bonenet`` (checkpoint paths), plus whatever
        ``predict_joints`` reads from it
    :raises ValueError: if a spanning tree has no joint whose parent is -1
    """
    # dtype=int: the np.int alias was removed in NumPy 1.24.
    # atleast_1d: loadtxt returns a 0-d array for a single-id file, which is not iterable.
    test_list = np.atleast_1d(
        np.loadtxt(os.path.join(args.dataset_folder, 'test_final.txt'), dtype=int))

    # map_location lets checkpoints saved on a GPU load on the current device.
    root_select_model = ROOTNET()
    root_select_model.to(device)
    root_select_model.eval()
    root_checkpoint = torch.load(args.rootnet, map_location=device)
    root_select_model.load_state_dict(root_checkpoint['state_dict'])

    connectivity_model = PairCls()
    connectivity_model.to(device)
    connectivity_model.eval()
    conn_checkpoint = torch.load(args.bonenet, map_location=device)
    connectivity_model.load_state_dict(conn_checkpoint['state_dict'])

    for model_id in test_list:
        print(model_id)
        pred_joints, vox = predict_joints(model_id, args)
        mesh_filename = os.path.join(args.dataset_folder, 'obj_remesh/{:d}.obj'.format(model_id))
        mesh = o3d.io.read_triangle_mesh(mesh_filename)
        surface_geodesic = calc_surface_geodesic(mesh)
        data = create_single_data(mesh, vox, surface_geodesic, pred_joints)
        root_id = getInitId(data, root_select_model)

        # Score every candidate joint pair; sigmoid maps the output to probabilities.
        with torch.no_grad():
            connect_logits, _ = connectivity_model(data)
            connect_prob = torch.sigmoid(connect_logits)
        pair_idx = data.pairs.long().data.cpu().numpy()

        # Dense, symmetrized probability matrix -> -log(p) edge costs, so a
        # minimum-cost tree maximizes the product of connection probabilities.
        num_joint = int(data.num_joint[0])
        prob_matrix = np.zeros((num_joint, num_joint))
        prob_matrix[pair_idx[:, 0], pair_idx[:, 1]] = connect_prob.data.cpu().numpy().squeeze()
        prob_matrix = prob_matrix + prob_matrix.transpose()
        cost_matrix = -np.log(prob_matrix + 1e-10)
        cost_matrix = increase_cost_for_outside_bone(cost_matrix, pred_joints, vox)

        skel = Skel()
        # primMST_symmetry returns (parent, key, root_id); only parent is needed.
        parent, _, _ = primMST_symmetry(cost_matrix, root_id, pred_joints)
        root_idx = next((idx for idx, par in enumerate(parent) if par == -1), None)
        if root_idx is None:
            raise ValueError(
                'spanning tree for model {} has no root joint (no parent == -1)'.format(model_id))
        skel.root = TreeNode('root', tuple(pred_joints[root_idx]))
        loadSkel_recur(skel.root, root_idx, None, pred_joints, parent)

        img = show_obj_skel(mesh_filename, skel.root)
        img_filename = os.path.join(args.res_folder, '{:d}_skel.jpg'.format(model_id))
        # Channel order is reversed before writing, presumably RGB -> BGR for
        # OpenCV; confirm against show_obj_skel.
        if not cv2.imwrite(img_filename, img[:, :, ::-1]):
            print('warning: failed to write {}'.format(img_filename))
        skel.save(os.path.join(args.res_folder, '{:d}_skel.txt'.format(model_id)))