def predict_joints(input_data, vox, joint_pred_net, threshold, bandwidth=None, mesh_filename=None):
    """
    Predict joints and attach a pair-wise bone representation to ``input_data``.

    Pipeline:
    1. The network predicts a per-vertex displacement and an attention value.
    2. Each vertex "votes" for ``pos + displacement``.
    3. Votes rejected by ``inside_check`` are dropped, as are votes whose attention is <= 1e-3.
    4. The remaining votes are mirrored across the x = 0 plane to symmetrize them.
    5. The votes are clustered with weighted mean-shift and filtered by normalized density.
    6. The result is reduced to joint positions with ``nms_meanshift`` and ``flip``.

    :param input_data: wrapped input data. Must expose ``pos`` and accept new attributes.
        Presumably a torch_geometric ``Data`` object (TODO confirm).
    :param vox: voxelized mesh, passed to ``inside_check``
    :param joint_pred_net: network returning ``(displacement, _, attention, bandwidth)``.
        The attention is assumed to be an (N, 1) column; only column 0 is used.
    :param threshold: density threshold used to filter out shifted points. It is compared
        against density normalized to sum to 1.
    :param bandwidth: bandwidth for meanshift clustering. If None, the network's
        predicted bandwidth is used.
    :param mesh_filename: mesh filename for visualization. It was only used by disabled
        debug drawing and is currently unused; kept for interface compatibility.
    :return: ``input_data`` with these attributes added:
        ``joints``: float tensor of predicted joint positions (J rows);
        ``pairs``: (P, 2) float tensor of joint index pairs ``(i, j)`` with ``i < j``;
        ``pair_attr``: (P, 3) float tensor of ``[distance, inside_ratio, 1]`` per pair;
        ``joints_batch`` / ``pairs_batch``: all-zero long tensors of length J / P.
    """
    data_displacement, _, attn_pred, bandwidth_pred = joint_pred_net(input_data)
    # Each vertex votes for a joint location: its position plus the predicted offset.
    y_pred = data_displacement + input_data.pos
    y_pred_np = y_pred.data.cpu().numpy()
    attn_pred_np = attn_pred.data.cpu().numpy()

    # Drop votes rejected by inside_check (presumably those outside the voxelized mesh).
    y_pred_np, index_inside = inside_check(y_pred_np, vox)
    attn_pred_np = attn_pred_np[index_inside, :]

    # Drop low-attention votes. Index column 0 rather than calling squeeze(): with a
    # single surviving vote, squeeze() would yield a 0-d mask that mis-indexes y_pred_np.
    keep = attn_pred_np[:, 0] > 1e-3
    y_pred_np = y_pred_np[keep]
    attn_pred_np = attn_pred_np[keep]

    # Symmetrize the votes by reflecting them across the x = 0 plane.
    y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
    y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
    attn_pred_np = np.tile(attn_pred_np, (2, 1))

    if bandwidth is None:
        bandwidth = bandwidth_pred.item()
    y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)

    # Kernel density of each shifted point, using the profile max(h^2 - d^2, 0) summed
    # over all points. Keep only points whose normalized density exceeds the threshold.
    sq_dist = np.sum(
        (y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2, axis=2)
    density = np.maximum(bandwidth ** 2 - sq_dist, 0.0).sum(axis=0)
    dense = density / np.sum(density) > threshold
    y_pred_np = y_pred_np[dense]
    density = density[dense]

    pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
    pred_joints, _ = flip(pred_joints)

    # Prepare new data members: one candidate bone per unordered joint pair (i < j).
    pairs = list(it.combinations(range(pred_joints.shape[0]), 2))
    pair_attr = []
    for j1, j2 in pairs:
        dist = np.linalg.norm(pred_joints[j1] - pred_joints[j2])
        bone_samples = sample_on_bone(pred_joints[j1], pred_joints[j2])
        bone_samples_inside, _ = inside_check(bone_samples, vox)
        # NOTE(review): historically named "outside_proportion", but this is the fraction
        # of bone samples that inside_check KEEPS (presumably those inside the mesh).
        # The value is left unchanged because downstream consumers (presumably a
        # trained connectivity network) expect exactly this feature.
        inside_ratio = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
        pair_attr.append([dist, inside_ratio, 1])

    # reshape keeps the documented 2-D shapes even when fewer than 2 joints were found
    # (np.array([]) alone would produce shape (0,)).
    pairs = torch.from_numpy(np.array(pairs, dtype=np.int64).reshape(-1, 2)).float()
    pair_attr = torch.from_numpy(np.array(pair_attr, dtype=np.float64).reshape(-1, 3)).float()
    pred_joints = torch.from_numpy(pred_joints).float()
    # Single-sample batch: every joint and every pair belongs to batch 0.
    joints_batch = torch.zeros(len(pred_joints), dtype=torch.long)
    pairs_batch = torch.zeros(len(pairs), dtype=torch.long)

    input_data.joints = pred_joints
    input_data.pairs = pairs
    input_data.pair_attr = pair_attr
    input_data.joints_batch = joints_batch
    input_data.pairs_batch = pairs_batch
    return input_data
def predict_joints(input_data, vox, joint_pred_net, threshold, bandwidth=None, mesh_filename=None):
    """
    Predict joints and attach per-pair bone features to ``input_data``.

    Pipeline:
    1. The network predicts a per-vertex displacement and an attention value.
    2. Each vertex "votes" for ``pos + displacement``.
    3. Votes rejected by ``inside_check`` are dropped, as are votes whose attention is <= 1e-3.
    4. The remaining votes are mirrored across the x = 0 plane to symmetrize them.
    5. The votes are clustered with weighted mean-shift and filtered by normalized density.
    6. The result is reduced to joint positions with ``nms_meanshift`` and ``flip``.

    :param input_data: wrapped input data. Must expose ``pos`` and accept new attributes.
        Presumably a torch_geometric ``Data`` object (TODO confirm).
    :param vox: voxelized mesh, passed to ``inside_check``
    :param joint_pred_net: network returning ``(displacement, _, attention, bandwidth)``.
        The attention is assumed to be an (N, 1) column; only column 0 is used.
    :param threshold: density threshold used to filter out shifted points. It is compared
        against density normalized to sum to 1.
    :param bandwidth: bandwidth for meanshift clustering. If None, the network's
        predicted bandwidth is used (converted to a Python float).
    :param mesh_filename: mesh filename for visualization. It was only used by disabled
        debug drawing and is currently unused; kept for interface compatibility.
    :return: ``input_data`` with these attributes added:
        ``y``: float tensor of predicted joints. When there are fewer joints than
        vertices (and at least one joint), the joints are repeated cyclically and
        truncated to ``len(input_data.pos)`` rows.
        ``num_joint``: ``[J]``, the number of distinct predicted joints;
        ``pairs``: (P, 5) float tensor with rows ``[i, j, distance, inside_ratio, 1]`` for i < j;
        ``num_pair``: ``[P]``.
    """
    data_displacement, _, attn_pred, bandwidth_pred = joint_pred_net(input_data)
    # Each vertex votes for a joint location: its position plus the predicted offset.
    y_pred = data_displacement + input_data.pos
    y_pred_np = y_pred.data.cpu().numpy()
    attn_pred_np = attn_pred.data.cpu().numpy()

    # Drop votes rejected by inside_check (presumably those outside the voxelized mesh).
    y_pred_np, index_inside = inside_check(y_pred_np, vox)
    attn_pred_np = attn_pred_np[index_inside, :]

    # Drop low-attention votes. Index column 0 rather than calling squeeze(): with a
    # single surviving vote, squeeze() would yield a 0-d mask that mis-indexes y_pred_np.
    keep = attn_pred_np[:, 0] > 1e-3
    y_pred_np = y_pred_np[keep]
    attn_pred_np = attn_pred_np[keep]

    # Symmetrize the votes by reflecting them across the x = 0 plane.
    y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
    y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
    attn_pred_np = np.tile(attn_pred_np, (2, 1))

    if bandwidth is None:
        # Convert to a plain float. The prediction is presumably a 1-element tensor
        # (possibly requiring grad or living on the GPU), and feeding it directly into
        # the numpy arithmetic below and in the clustering helpers would fail.
        bandwidth = float(bandwidth_pred)
    y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)

    # Kernel density of each shifted point, using the profile max(h^2 - d^2, 0) summed
    # over all points. Keep only points whose normalized density exceeds the threshold.
    sq_dist = np.sum(
        (y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2, axis=2)
    density = np.maximum(bandwidth ** 2 - sq_dist, 0.0).sum(axis=0)
    dense = density / np.sum(density) > threshold
    y_pred_np = y_pred_np[dense]
    density = density[dense]

    pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
    pred_joints, _ = flip(pred_joints)

    # Prepare new data members: one row per unordered joint pair (i < j).
    num_joint = len(pred_joints)
    pair_all = []
    for joint1_id in range(num_joint):
        for joint2_id in range(joint1_id + 1, num_joint):
            dist = np.linalg.norm(pred_joints[joint1_id] - pred_joints[joint2_id])
            bone_samples = sample_on_bone(pred_joints[joint1_id], pred_joints[joint2_id])
            bone_samples_inside, _ = inside_check(bone_samples, vox)
            # NOTE(review): historically named "outside_proportion", but this is the fraction
            # of bone samples that inside_check KEEPS (presumably those inside the mesh).
            # The value is left unchanged because downstream consumers expect this feature.
            inside_ratio = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
            pair_all.append([joint1_id, joint2_id, dist, inside_ratio, 1])
    # reshape keeps the (P, 5) layout even when fewer than 2 joints were found.
    pair_all = np.array(pair_all, dtype=np.float64).reshape(-1, 5)
    num_pair = len(pair_all)
    pair_all = torch.from_numpy(pair_all).float()

    num_pos = len(input_data.pos)
    if 0 < num_joint < num_pos:
        # Repeat the joints cyclically until there are at least num_pos rows, then truncate.
        # This presumably keeps y row-aligned with pos for batching (TODO confirm).
        # The repeat count is the integer ceil(num_pos / num_joint). The num_joint > 0
        # guard avoids a ZeroDivisionError when no joints were found.
        reps = -(-num_pos // num_joint)
        pred_joints = np.tile(pred_joints, (reps, 1))[:num_pos, :]
    pred_joints = torch.from_numpy(pred_joints).float()

    input_data.y = pred_joints
    input_data.num_joint = [num_joint]
    input_data.pairs = pair_all
    input_data.num_pair = [num_pair]
    return input_data