示例#1
0
def refine_pose(pose, refiner):
    # type: (Sequence[Sequence[float]], Refiner) -> Optional[List[np.ndarray]]
    """
    Refine a 3D pose with the given refiner model.

    The input joints are arranged into a 14-joint `Pose` (one slot per joint
    type), converted to a root-relative pose, passed through `refiner` and
    converted back using the position of joint 0 as the root.

    :param pose: list of joints where each joint is in the form [jtype, x3d, y3d, z3d]
        >> a joint type that does not appear exactly once in `pose` is treated
        as missing and replaced by a placeholder joint
    :param refiner: pose refiner model
    :return: refined pose -> list of 14 ordered joints where each joint is in the form [x3d, y3d, z3d]
        >> see `Joint.NAMES` for joint order
        >> `None` if the type of the first joint of the built pose is not 0
    """

    # convert `pose` list into a `Pose` object (one joint per joint type)
    joints = []
    for jtype in range(14):
        matches = [j for j in pose if j[0] == jtype]
        if len(matches) == 1:
            x, y, z = matches[0][1], matches[0][2], matches[0][3]
            raw_joint = [-1, -1, jtype, -1, -1, x, y, z, 0, 0]
        else:
            # missing (or ambiguous) joint: placeholder coordinates and the
            # last two fields set to 1
            # NOTE(review): presumably these two fields are the flags that
            # drive `Joint.visible` -- confirm against the `Joint` class
            raw_joint = [-1, -1, jtype, -1, -1, -1, -1, -1, 1, 1]
        joints.append(Joint(np.array(raw_joint)))
    pose_obj = Pose(joints)

    # convert `Pose` object into a root-relative pose; entry `jtype - 1`
    # corresponds to joint `jtype`, and non-visible joints are marked as
    # holes with the value [-1, -1, -1]
    rr_pose_pred = pose_obj.to_rr_pose(MAX_LS)
    for jtype in range(1, 14):
        if not pose_obj[jtype].visible:
            rr_pose_pred[jtype - 1] = np.array([-1, -1, -1])
    rr_pose_pred = torch.tensor(rr_pose_pred).unsqueeze(0).float()

    # refine the root-relative pose with the `refiner` model;
    # `detach().cpu()` is required before `numpy()` whenever the output
    # tracks gradients or lives on a non-CPU device
    refined_rr_pose_pred = refiner.forward(rr_pose_pred)
    refined_rr_pose_pred = refined_rr_pose_pred.detach().cpu().numpy().squeeze()

    if pose_obj[0].type == 0:
        # back from root-relative pose to absolute coordinates
        return Pose.from_rr_pose(refined_rr_pose_pred,
                                 head_pos3d=pose_obj[0].pos3d,
                                 max_ls=MAX_LS)
    return None
    def refine(self,
               pose,
               confidences,
               hole_th,
               replace_th=0.4,
               head_pos3d=None):
        # type: (JointList, List[float], float, float, Tuple[float, float, float]) -> JointList
        """
        Refine `pose` with this model, treating low-confidence joints as missing.

        Every joint from index 1 to 13 whose confidence is <= `hole_th` is
        masked out (set to [-1, -1, -1]) in the root-relative pose fed to the
        model. In the returned pose, a joint is swapped for its refined
        version only when its confidence is <= `replace_th`; otherwise the
        input joint is returned as it was given.

        :param pose: pose to refine
            >> it is list of 14 quadruples of type (jtype, x3d, y3d, z3d)
            >> the pose must always have 14 joints; to encode any "holes" use
            random coordinates with a confidence value <= 0
        :param confidences: confidence values of pose joints
            >> it is a list of 14 values such that `confidences[i]`
            is the confidence of the i-th joint of the pose
        :param hole_th: confidence threshold; joints with
            confidence <= `hole_th` are considered holes
        :param replace_th: replace a joint with its refined version only
            if its confidence is <= `replace_th`
        :param head_pos3d: 3D position of the head;
            >> if `None`, it is set to the coordinates of the first joint of `pose`
        :return: refined version of the input pose
            >> list of (jtype, x3d, y3d, z3d) entries, one per joint of the
            refined pose
        """
        from joint import Joint

        if head_pos3d is None:
            head_pos3d = pose[0][1:]

        # keep an untouched copy of the input entries for the final merge
        original_entries = deepcopy(pose)

        # coords -> Pose
        pose_obj = Pose(joints=[
            Joint(np.array([-1, -1, jtype, -1, -1, x, y, z, -1, -1]))
            for jtype, x, y, z in pose
        ])

        # Pose -> RR-Pose (root relative pose), computed on a copy of the Pose
        rr_pose = deepcopy(pose_obj).to_rr_pose(max_ls=np.array(utils.MAX_LS))

        # joints whose confidence is less than or equal to `hole_th` are
        # holes; joint `jtype` is stored at position `jtype - 1` of the RR-Pose
        hole_jtypes = [j for j in range(1, 14) if confidences[j] <= hole_th]
        for hole_jtype in hole_jtypes:
            rr_pose[hole_jtype - 1] = np.array([-1, -1, -1])

        # build the model input on the same device as the model weights
        model_input = torch.tensor(rr_pose).unsqueeze(0)
        device = self.state_dict()['fc1.0.weight'].device
        model_input = model_input.to(device).float()

        # predict refined RR-Pose
        model_output = self.forward(model_input)
        refined_rr_pose = model_output.detach().cpu().numpy().squeeze()

        # refined RR-Pose -> refined Pose, anchored at `head_pos3d`
        refined_pose = Pose.from_rr_pose(refined_rr_pose.copy(),
                                         head_pos3d=head_pos3d,
                                         max_ls=np.array(utils.MAX_LS))

        # confident joints keep their input value, the others take the
        # refined coordinates
        return [
            original_entries[jtype]
            if confidences[jtype] > replace_th
            else (jtype, refined[0], refined[1], refined[2])
            for jtype, refined in enumerate(refined_pose)
        ]