Esempio n. 1
0
def predict_joints(model_id, args):
    """
    Predict joints for a specified model from saved network outputs.

    Loads, from ``args.res_folder``, the shifted points (``<id>.ply``), their
    attention weights (``<id>_attn.npy``) and the mean-shift bandwidth (first
    element of ``<id>_bandwidth.npy``), plus the voxelized mesh
    ``<args.dataset_folder>/vox/<id>.binvox``. Points not kept by
    ``inside_check`` or with attention <= 1e-3 are dropped. The remaining
    points are mirrored across the x = 0 plane and clustered with mean-shift.
    Joints are then extracted by density thresholding, NMS and ``flip``.

    If fewer than two joints survive, the density threshold is repeatedly
    divided by 1.3. Once it drops to 1e-7 or below, NMS is run over all
    clustered points instead.

    :param model_id: processed model ID number (an int; formatted with ``{:d}``)
    :param args: options object; this function reads ``dataset_folder``,
        ``res_folder`` and ``threshold_best``
    :return: tuple ``(pred_joints, vox)``: the predicted joints (as returned
        by ``flip``) and the voxel grid read by
        ``binvox_rw.read_as_3d_array``
    """
    vox_folder = os.path.join(args.dataset_folder, 'vox/')
    raw_pred = os.path.join(args.res_folder, '{:d}.ply'.format(model_id))
    vox_file = os.path.join(vox_folder, '{:d}.binvox'.format(model_id))
    pred_attn = np.load(
        os.path.join(args.res_folder, '{:d}_attn.npy'.format(model_id)))

    with open(vox_file, 'rb') as fvox:
        vox = binvox_rw.read_as_3d_array(fvox)
    pred_joints = readPly(raw_pred)
    pred_joints, index_inside = inside_check(pred_joints, vox)
    pred_attn = pred_attn[index_inside, :]

    bandwidth = np.load(
        os.path.join(args.res_folder, '{:d}_bandwidth.npy'.format(model_id)))
    bandwidth = bandwidth[0]

    # Drop points with negligible attention. ravel() rather than squeeze() so
    # a single remaining point still yields a 1-D mask (squeeze would give a
    # 0-d array, which changes the meaning of boolean indexing).
    keep = pred_attn.ravel() > 1e-3
    pred_joints = pred_joints[keep]
    pred_attn = pred_attn[keep]

    # Reflect the points across x = 0 (negate x) and duplicate their weights.
    pred_joints_reflect = pred_joints * np.array([[-1, 1, 1]])
    pred_joints = np.concatenate((pred_joints, pred_joints_reflect), axis=0)
    pred_attn = np.tile(pred_attn, (2, 1))

    pred_joints = meanshift_cluster(pred_joints,
                                    bandwidth,
                                    pred_attn,
                                    max_iter=20)

    # Per-point density: sum over all points of max(bandwidth^2 - d^2, 0).
    sq_dist = np.sum(
        (pred_joints[np.newaxis, ...] - pred_joints[:, np.newaxis, :])**2,
        axis=2)
    density = np.maximum(bandwidth**2 - sq_dist, np.zeros(sq_dist.shape))
    density = np.sum(density, axis=0)
    density_ratio = density / np.sum(density)

    def _joints_above(threshold):
        """Run NMS + flip on clustered points whose density ratio exceeds `threshold`."""
        # A single mask for both arrays keeps points and densities aligned.
        mask = density_ratio > threshold
        joints = nms_meanshift(pred_joints[mask], density[mask], bandwidth)
        joints, _ = flip(joints)
        return joints

    reduce_threshold = args.threshold_best
    pred_joints_ = _joints_above(reduce_threshold)
    while len(pred_joints_) < 2 and reduce_threshold > 1e-7:
        reduce_threshold = reduce_threshold / 1.3
        pred_joints_ = _joints_above(reduce_threshold)
    if reduce_threshold <= 1e-7:
        # Threshold exhausted: run NMS over every clustered point, so that the
        # points passed in line up one-to-one with the full density array.
        pred_joints_ = nms_meanshift(pred_joints, density, bandwidth)
        pred_joints_, _ = flip(pred_joints_)

    return pred_joints_, vox
Esempio n. 2
0
def predict_joints(input_data,
                   vox,
                   joint_pred_net,
                   threshold,
                   bandwidth=None,
                   mesh_filename=None):
    """
    Predict joints and attach pair-wise bone candidates to ``input_data``.

    Runs ``joint_pred_net`` to get per-point displacements, attention and a
    bandwidth. It shifts ``input_data.pos`` by the displacements and drops
    points not kept by ``inside_check`` or with attention <= 1e-3. The rest
    are mirrored across the x = 0 plane and clustered with mean-shift.
    Joints are then extracted with density thresholding, NMS and ``flip``.
    Every unordered pair of joints becomes a candidate bone.

    :param input_data: wrapped input data (must have ``.pos``; presumably a
        torch_geometric ``Data`` object, TODO confirm). It is mutated and
        returned.
    :param vox: voxelized mesh
    :param joint_pred_net: network for predicting joints; called as
        ``joint_pred_net(input_data)`` and expected to return
        ``(displacement, _, attention, bandwidth)``
    :param threshold: density-ratio threshold to filter out shifted points
    :param bandwidth: bandwidth for meanshift clustering; if None, the
        network's predicted bandwidth is used
    :param mesh_filename: mesh filename for visualization; currently unused,
        kept for backward compatibility
    :return: ``input_data`` with these attributes added:

        - ``joints``: float tensor ``(J, 3)``
        - ``pairs``: float tensor ``(P, 2)`` of joint indices
        - ``pair_attr``: float tensor ``(P, 3)``
        - ``joints_batch`` and ``pairs_batch``: long zero tensors
    """
    data_displacement, _, attn_pred, bandwidth_pred = joint_pred_net(
        input_data)
    y_pred = data_displacement + input_data.pos
    y_pred_np = y_pred.data.cpu().numpy()
    attn_pred_np = attn_pred.data.cpu().numpy()
    y_pred_np, index_inside = inside_check(y_pred_np, vox)
    attn_pred_np = attn_pred_np[index_inside, :]

    # Drop points with negligible attention. ravel() keeps the mask 1-D even
    # when only one point remains (squeeze() would produce a 0-d mask).
    keep = attn_pred_np.ravel() > 1e-3
    y_pred_np = y_pred_np[keep]
    attn_pred_np = attn_pred_np[keep]

    # Symmetrize the points by reflecting them across x = 0.
    y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
    y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
    attn_pred_np = np.tile(attn_pred_np, (2, 1))

    if bandwidth is None:
        bandwidth = bandwidth_pred.item()
    y_pred_np = meanshift_cluster(y_pred_np,
                                  bandwidth,
                                  attn_pred_np,
                                  max_iter=40)

    # Per-point density: sum over all points of max(bandwidth^2 - d^2, 0).
    sq_dist = np.sum(
        (y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :])**2,
        axis=2)
    density = np.maximum(bandwidth**2 - sq_dist, np.zeros(sq_dist.shape))
    density = np.sum(density, axis=0)
    dense = density / np.sum(density) > threshold

    pred_joints = nms_meanshift(y_pred_np[dense], density[dense], bandwidth)
    pred_joints, _ = flip(pred_joints)

    # Every unordered pair of joints is a candidate bone.
    pairs = list(it.combinations(range(pred_joints.shape[0]), 2))
    pair_attr = []
    for j1, j2 in pairs:
        dist = np.linalg.norm(pred_joints[j1] - pred_joints[j2])
        bone_samples = sample_on_bone(pred_joints[j1], pred_joints[j2])
        bone_samples_inside, _ = inside_check(bone_samples, vox)
        # NOTE(review): despite the historical name "outside_proportion",
        # this is the fraction of bone samples kept by inside_check
        # (presumably the inside fraction). The value is left unchanged
        # because downstream models were presumably trained on it.
        inside_fraction = len(bone_samples_inside) / (len(bone_samples) +
                                                      1e-10)
        pair_attr.append([dist, inside_fraction, 1])

    # reshape keeps the expected 2-D shapes even when there are no pairs.
    pairs = torch.from_numpy(
        np.asarray(pairs, dtype=np.int64).reshape(-1, 2)).float()
    pair_attr = torch.from_numpy(
        np.asarray(pair_attr, dtype=np.float64).reshape(-1, 3)).float()
    pred_joints = torch.from_numpy(pred_joints).float()

    input_data.joints = pred_joints
    input_data.pairs = pairs
    input_data.pair_attr = pair_attr
    input_data.joints_batch = torch.zeros(len(pred_joints), dtype=torch.long)
    input_data.pairs_batch = torch.zeros(len(pairs), dtype=torch.long)
    return input_data
Esempio n. 3
0
def predict_joints(input_data,
                   vox,
                   joint_pred_net,
                   threshold,
                   bandwidth=None,
                   mesh_filename=None):
    """
    Predict joints and store them, plus candidate bone pairs, on ``input_data``.

    Runs ``joint_pred_net`` to get per-point displacements, attention and a
    bandwidth. It shifts ``input_data.pos`` by the displacements and drops
    points not kept by ``inside_check`` or with attention <= 1e-3. The rest
    are mirrored across the x = 0 plane and clustered with mean-shift.
    Joints are then extracted with density thresholding, NMS and ``flip``.

    :param input_data: wrapped input data (must have ``.pos``; presumably a
        torch_geometric ``Data`` object, TODO confirm). It is mutated and
        returned.
    :param vox: voxelized mesh
    :param joint_pred_net: network for predicting joints; called as
        ``joint_pred_net(input_data)`` and expected to return
        ``(displacement, _, attention, bandwidth)``
    :param threshold: density-ratio threshold to filter out shifted points
    :param bandwidth: bandwidth for meanshift clustering; if None, the
        network's predicted bandwidth is used
    :param mesh_filename: mesh filename for visualization; currently unused,
        kept for backward compatibility
    :return: ``input_data`` with these attributes set:

        - ``y``: joints tiled cyclically to ``len(input_data.pos)`` rows,
          float tensor
        - ``num_joint``: ``[J]``
        - ``pairs``: float tensor ``(P, 5)`` with rows
          ``[j1, j2, dist, inside_fraction, 1]``
        - ``num_pair``: ``[P]``
    """
    import itertools

    data_displacement, _, attn_pred, bandwidth_pred = joint_pred_net(
        input_data)
    y_pred = data_displacement + input_data.pos
    y_pred_np = y_pred.data.cpu().numpy()
    attn_pred_np = attn_pred.data.cpu().numpy()
    y_pred_np, index_inside = inside_check(y_pred_np, vox)
    attn_pred_np = attn_pred_np[index_inside, :]

    # Drop points with negligible attention. ravel() keeps the mask 1-D even
    # when only one point remains (squeeze() would produce a 0-d mask).
    keep = attn_pred_np.ravel() > 1e-3
    y_pred_np = y_pred_np[keep]
    attn_pred_np = attn_pred_np[keep]

    # Symmetrize the points by reflecting them across x = 0.
    y_pred_np_reflect = y_pred_np * np.array([[-1, 1, 1]])
    y_pred_np = np.concatenate((y_pred_np, y_pred_np_reflect), axis=0)
    attn_pred_np = np.tile(attn_pred_np, (2, 1))

    if bandwidth is None:
        # Convert the network output to a plain Python float. A (possibly
        # GPU) tensor must not leak into the numpy arithmetic below.
        if hasattr(bandwidth_pred, 'item'):
            bandwidth = bandwidth_pred.item()
        else:
            bandwidth = float(bandwidth_pred)
    y_pred_np = meanshift_cluster(y_pred_np,
                                  bandwidth,
                                  attn_pred_np,
                                  max_iter=40)

    # Per-point density: sum over all points of max(bandwidth^2 - d^2, 0).
    sq_dist = np.sum(
        (y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :])**2,
        axis=2)
    density = np.maximum(bandwidth**2 - sq_dist, np.zeros(sq_dist.shape))
    density = np.sum(density, axis=0)
    dense = density / np.sum(density) > threshold

    pred_joints = nms_meanshift(y_pred_np[dense], density[dense], bandwidth)
    pred_joints, _ = flip(pred_joints)

    # Every unordered pair of joints is a candidate bone.
    num_joint = len(pred_joints)
    pair_rows = []
    for joint1_id, joint2_id in itertools.combinations(range(num_joint), 2):
        dist = np.linalg.norm(pred_joints[joint1_id] - pred_joints[joint2_id])
        bone_samples = sample_on_bone(pred_joints[joint1_id],
                                      pred_joints[joint2_id])
        bone_samples_inside, _ = inside_check(bone_samples, vox)
        # NOTE(review): historically named "outside_proportion", but this is
        # the fraction of bone samples kept by inside_check (presumably the
        # inside fraction). The value is left unchanged because downstream
        # models were presumably trained on it.
        inside_fraction = len(bone_samples_inside) / (
            len(bone_samples) + 1e-10)
        pair_rows.append([joint1_id, joint2_id, dist, inside_fraction, 1])
    num_pair = len(pair_rows)
    # reshape keeps a (0, 5) shape when there are fewer than two joints.
    pair_all = torch.from_numpy(
        np.asarray(pair_rows, dtype=np.float64).reshape(-1, 5)).float()

    # Repeat the joints cyclically so there are at least as many rows as
    # input points, then truncate to exactly len(input_data.pos) rows.
    # Presumably this lets `y` be batched per-node, TODO confirm.
    # Skip tiling when no joints were found (it would divide by zero).
    num_pos = len(input_data.pos)
    if 0 < num_joint < num_pos:
        reps = -(-num_pos // num_joint)  # integer ceil(num_pos / num_joint)
        pred_joints = np.tile(pred_joints, (reps, 1))
    pred_joints = pred_joints[:num_pos, :]
    pred_joints = torch.from_numpy(pred_joints).float()

    input_data.y = pred_joints
    input_data.num_joint = [num_joint]
    input_data.pairs = pair_all
    input_data.num_pair = [num_pair]
    return input_data