def get_joint_Hs(HR, R, T):
    th_full_hand_pose = HR.unsqueeze(0).mm(MANO.th_selected_comps)
    th_full_pose = torch.cat(
        [R.unsqueeze(0), MANO.th_hands_mean + th_full_hand_pose], 1)
    th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
    th_full_pose = th_full_pose.view(1, -1, 3)
    root_rot = rodrigues_layer.batch_rodrigues(th_full_pose[:, 0]).view(1, 3, 3)

    th_v_shaped = torch.matmul(MANO.th_shapedirs, MANO.th_betas.transpose(
        1, 0)).permute(2, 0, 1) + MANO.th_v_template
    th_j = torch.matmul(MANO.th_J_regressor, th_v_shaped).repeat(1, 1, 1)

    root_j = th_j[:, 0, :].contiguous().view(1, 3, 1)
    th_results = []
    th_results.append(th_with_zeros(torch.cat([root_rot, root_j], 2)))
    angle_parents = [
        4294967295, 0, 1, 2, 0, 4, 5, 0, 7, 8, 0, 10, 11, 0, 13, 14
    ]  # root's parent is stored as uint32 -1 (4294967295), as in MANO's kintree_table
    # Rotate each part
    for i in range(15):
        i_val = int(i + 1)
        joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val * 9].contiguous().view(
            1, 3, 3)
        joint_j = th_j[:, i_val, :].contiguous().view(1, 3, 1)
        parent = make_list(angle_parents)[i_val]
        parent_j = th_j[:, parent, :].contiguous().view(1, 3, 1)
        joint_rel_transform = th_with_zeros(
            torch.cat([joint_rot, joint_j - parent_j], 2))
        th_results.append(torch.matmul(th_results[parent],
                                       joint_rel_transform))

    Hs = torch.cat(th_results)
    Hs[:, :3, 3] = Hs[:, :3, 3] + T
    return Hs
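A minimal usage sketch of get_joint_Hs (hedged: it assumes a module-level MANO layer is already loaded as in these examples, and that the layer was built with ncomps=45 so th_selected_comps is (45, 45)):

# HR: (45,) PCA pose coefficients, R: (3,) global axis-angle, T: (1, 3) translation
HR = torch.zeros(45)
R = torch.zeros(3)
T = torch.zeros(1, 3)
Hs = get_joint_Hs(HR, R, T)  # (16, 4, 4): one world 4x4 transform per MANO joint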
Example #2
def th_posemap_axisang(pose_vectors):
    rot_nb = int(pose_vectors.shape[1] / 3)
    pose_vec_reshaped = pose_vectors.contiguous().view(-1, 3)
    rot_mats = rodrigues_layer.batch_rodrigues(pose_vec_reshaped)
    rot_mats = rot_mats.view(pose_vectors.shape[0], rot_nb * 9)
    pose_maps = subtract_flat_id(rot_mats)
    return pose_maps, rot_mats
def rotate_verts(verts, axisang=(0, 1, 1)):
    centroids = verts.mean(1).unsqueeze(1)
    verts_c = verts - centroids
    rot_mats = rodrigues_layer.batch_rodrigues(
        verts.new(axisang).unsqueeze(0)).view(1, 3, 3)
    verts_cr = rot_mats.repeat(verts.shape[0], 1,
                               1).bmm(verts_c.transpose(1, 2)).transpose(1, 2)
    verts_final = verts_cr + centroids
    return verts_final
def get_rot_map(HR, R):
    pose_representation = torch.cat((R, HR), 1)  # computed but unused below
    th_full_hand_pose = HR.mm(MANO.th_selected_comps)
    th_full_pose = torch.cat([R, MANO.th_hands_mean + th_full_hand_pose], 1)
    th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
    th_full_pose = th_full_pose.view(1, -1, 3)
    root_rot = rodrigues_layer.batch_rodrigues(th_full_pose[:, 0]).view(1, 3, 3)

    return th_pose_map, th_rot_map, root_rot
def th_posemap_axisang(pose_vectors):
    rot_nb = int(pose_vectors.shape[1] / 3)
    rot_mats = []
    for joint_idx in range(rot_nb - 1):
        joint_idx_val = joint_idx + 1
        axis_ang = pose_vectors[:, joint_idx_val * 3:(joint_idx_val + 1) * 3]
        rot_mat = rodrigues_layer.batch_rodrigues(axis_ang)
        rot_mats.append(rot_mat)

    # rot_mats = torch.stack(rot_mats, 1).view(-1, 15 *9)
    rot_mats = torch.cat(rot_mats, 1)
    pose_maps = subtract_flat_id(rot_mats)
    return pose_maps, rot_mats
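# Note: unlike the first th_posemap_axisang above, this variant skips the root
# joint, so rot_mats has 15 * 9 = 135 columns instead of 16 * 9 = 144. The
# callers in these examples (get_joint_Hs, get_hand) index th_rot_map starting
# at column (i_val - 1) * 9 and multiply th_pose_map against th_posedirs'
# 135 pose-blend columns, i.e. they assume this root-free layout.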
Example #6
    def __init__(
        self,
        center_idx=None,
        flat_hand_mean=True,
        ncomps=6,
        side="right",
        mano_root="mano/models",
        use_pca=True,
        root_rot_mode="axisang",
        joint_rot_mode="axisang",
        robust_rot=False,
    ):
        """
        Args:
            center_idx: index of center joint in our computations,
                if -1 centers on estimate of palm as middle of base
                of middle finger and wrist
            flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
                flat hand, else match average hand pose
            mano_root: path to MANO pkl files for left and right hand
            ncomps: number of PCA components from pose space (<45)
            side: 'right' or 'left'
            use_pca: Use PCA decomposition for pose space.
            joint_rot_mode: 'axisang' or 'rotmat', ignored if use_pca
        """
        super().__init__()

        self.center_idx = center_idx
        self.robust_rot = robust_rot
        if root_rot_mode == "axisang":
            self.rot = 3
        else:
            self.rot = 6
        self.flat_hand_mean = flat_hand_mean
        self.side = side
        self.use_pca = use_pca
        self.joint_rot_mode = joint_rot_mode
        self.root_rot_mode = root_rot_mode
        if use_pca:
            self.ncomps = ncomps
        else:
            self.ncomps = 45

        if side == "right":
            self.mano_path = os.path.join(mano_root, "MANO_RIGHT.pkl")
        elif side == "left":
            self.mano_path = os.path.join(mano_root, "MANO_LEFT.pkl")

        smpl_data = ready_arguments(self.mano_path)

        hands_components = smpl_data["hands_components"]

        self.smpl_data = smpl_data

        self.register_buffer("th_betas",
                             torch.Tensor(smpl_data["betas"].r).unsqueeze(0))
        self.register_buffer("th_shapedirs",
                             torch.Tensor(smpl_data["shapedirs"].r))
        self.register_buffer("th_posedirs",
                             torch.Tensor(smpl_data["posedirs"].r))
        self.register_buffer(
            "th_v_template",
            torch.Tensor(smpl_data["v_template"].r).unsqueeze(0),
        )
        self.register_buffer(
            "th_J_regressor",
            torch.Tensor(np.array(smpl_data["J_regressor"].toarray())),
        )
        self.register_buffer("th_weights",
                             torch.Tensor(smpl_data["weights"].r))
        self.register_buffer(
            "th_faces",
            torch.Tensor(smpl_data["f"].astype(np.int32)).long())

        # Get hand mean
        hands_mean = (np.zeros(hands_components.shape[1])
                      if flat_hand_mean else smpl_data["hands_mean"])
        hands_mean = hands_mean.copy()
        th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
        if self.use_pca or self.joint_rot_mode == "axisang":
            # Save as axis-angle
            self.register_buffer("th_hands_mean", th_hands_mean)
            selected_components = hands_components[:ncomps]
            self.register_buffer("th_comps", torch.Tensor(hands_components))
            self.register_buffer("th_selected_comps",
                                 torch.Tensor(selected_components))
        else:
            th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
                th_hands_mean.view(15, 3)).reshape(15, 3, 3)
            self.register_buffer("th_hands_mean_rotmat", th_hands_mean_rotmat)

        # Kinematic chain params
        self.kintree_table = smpl_data["kintree_table"]
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents
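
        # Note: register_buffer stores these MANO tensors as non-trainable
        # module state: they move with .to()/.cuda() and are saved in the
        # state_dict but, unlike nn.Parameter, receive no gradients.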
    def forward(self,
                sample,
                scaletrans=None,
                scale=None,
                trans=None,
                rotaxisang=None):
        """
        Args:
            scaletrans: torch.Tensor of shape [batch_size, 6]: channel 0 holds
                the predicted scale, channels 1-2 the predicted translation,
                and channels 3-5 the global rotation encoded as an axis-angle
        """
        if scaletrans is None:
            batch_size = scale.shape[0]
        else:
            batch_size = scaletrans.shape[0]
        if scale is None:
            scale = scaletrans[:, :1]
        if trans is None:
            trans = scaletrans[:, 1:3]
        if rotaxisang is None:
            rotaxisang = scaletrans[:, 3:]
        # Get rotation matrices from axis-angles
        rotmat = rodrigues_layer.batch_rodrigues(rotaxisang).view(
            rotaxisang.shape[0], 3, 3)
        canobjverts = sample[BaseQueries.OBJCANVERTS].cuda()
        rotobjverts = rotmat.bmm(
            canobjverts.float().transpose(1, 2)).transpose(1, 2)

        final_trans = trans.unsqueeze(1) * self.trans_factor
        final_scale = scale.view(batch_size, 1, 1) * self.scale_factor
        height, width = tuple(sample[TransQueries.IMAGE].shape[2:])
        camintr = sample[TransQueries.CAMINTR].cuda()
        objverts3d, center3d = project.recover_3d_proj(
            rotobjverts, camintr, final_scale, final_trans,
            input_res=(width, height))
        # Recover 2D positions given camera intrinsic parameters and object vertex
        # coordinates in camera coordinate reference
        pred_objverts2d = camproject.batch_proj2d(objverts3d, camintr)
        if BaseQueries.OBJCORNERS3D in sample:
            canobjcorners = sample[BaseQueries.OBJCANCORNERS].cuda()
            rotobjcorners = rotmat.bmm(canobjcorners.float().transpose(
                1, 2)).transpose(1, 2)
            recov_objcorners3d = rotobjcorners + center3d
            pred_objcorners2d = camproject.batch_proj2d(
                rotobjcorners + center3d, camintr)
        else:
            pred_objcorners2d = None
            recov_objcorners3d = None
            rotobjcorners = None
        return {
            "obj_verts2d": pred_objverts2d,
            "obj_verts3d": rotobjverts,
            "recov_objverts3d": objverts3d,
            "recov_objcorners3d": recov_objcorners3d,
            "obj_scale": final_scale,
            "obj_prescale": scale,
            "obj_prerot": rotaxisang,
            "obj_trans": final_trans,
            "obj_pretrans": trans,
            "obj_corners2d": pred_objcorners2d,
            "obj_corners3d": rotobjcorners,
        }
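A minimal call sketch (hedged: `model` and the contents of `sample` are assumptions; the keys match those used above):

# scaletrans packs [scale | trans_x, trans_y | rot_x, rot_y, rot_z] per row
scaletrans = torch.cat([scale.view(-1, 1), trans, rotaxisang], dim=1)  # (B, 6)
out = model(sample, scaletrans=scaletrans)
pred2d = out["obj_verts2d"]  # (B, N, 2) reprojected object vertices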
Example #8
def optimize_hand(handfullpose, R, T, obj_verts, step=20):
    # Links between joints are:
    # Thumb: 0, 1, 2, 3, 4
    # Index: 0, 5, 6, 7, 8
    # Middle: 0, 9, 10, 11, 12
    # Fourth: 0, 13, 14, 15, 16
    # Small: 0, 17, 18, 19, 20

    # Thumb has two degrees of freedom when closing: the horizontal DOF is
    # given by the prediction and the vertical DOF (grasp direction) by the
    # following R:
    #Rthumb_vert = torch.FloatTensor([[ 0.8196, -0.4868, -0.3022],
    #                                [-0.1282, 0.3583, -0.9247],
    #                                [0.5585, 0.7966, 0.2313]]).cuda()
    #thumb_pred = rodrigues_layer.batch_rodrigues(handfullpose[:, 36:39]).view(3, 3)
    ## Limit the actual rotation because this was from open to close. Let's interpolate

    # SLERP between the quaternion of that rotation matrix and quat(1, 0, 0, 0)
    # gives:
    # 0.3:
    # array([[ 0.90412033, -0.31296735, -0.29089149],
    #        [-0.01381291,  0.65903709, -0.75198359],
    #        [ 0.42705459,  0.68390171,  0.59152585]])
    #
    # 0.5:
    # array([[ 0.94920308, -0.20209229, -0.24118918],
    #        [ 0.02896454,  0.8193583 , -0.57254959],
    #        [ 0.31332821,  0.5364799 ,  0.78359093]])
    #
    # 0.7:
    # array([[ 0.98125081, -0.10487662, -0.16170265],
    #        [ 0.040975  ,  0.93332498, -0.35668689],
    #        [ 0.18832923,  0.34337353,  0.92012321]])
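    # For reference, the matrices above can be reproduced with a quaternion
    # SLERP between the identity and Rthumb_vert (hedged sketch, assuming
    # scipy is available; Rthumb_vert_np is the commented matrix above as a
    # numpy array):
    #   from scipy.spatial.transform import Rotation, Slerp
    #   keys = Rotation.from_matrix(np.stack([np.eye(3), Rthumb_vert_np]))
    #   Slerp([0.0, 1.0], keys)([0.3, 0.5, 0.7]).as_matrix()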

    #Rmat_open = torch.matmul(torch.inverse(Rthumb_vert), thumb_pred)
    #Rmat_closed = torch.matmul(Rthumb_vert, thumb_pred)

    thumb_pred = rodrigues_layer.batch_rodrigues(handfullpose[:, 36:39]).view(3, 3)

    # NOTE: this is not accurate when thumb_pred is very close to the resting
    # thumb position, nor when the composed rotation is very close to it; in
    # that case the thumb is simply not optimized. To solve this, check the
    # angle from handfullpose[:, 36:39] and see whether it is further than
    # limit_bigfinger or too close.

    # Double the rotation, limited to a maximum difference with respect to the
    # root thumb position.

    #Rmat_closed = torch.matmul(thumb_pred, thumb_pred)
    #Rmat_closed = torch.min(Rmat_closed, torch.FloatTensor([0.99999]).cuda())
    #Rmat_closed = torch.max(Rmat_closed, torch.FloatTensor([-0.99999]).cuda())

    # NOTE: alternative approach: cap the thumb rotation at a maximum angle by
    # converting the rotation matrix to Euler angles, rescaling, and converting back.
    thumb_pred = torch.min(thumb_pred, torch.FloatTensor([0.99999]).cuda())
    thumb_pred = torch.max(thumb_pred, torch.FloatTensor([-0.99999]).cuda())
    eu = transforms3d.euler.mat2euler(thumb_pred.cpu().data.numpy())
    eu = np.array(eu)
    #eu = eu*2.0/np.sqrt((eu**2).sum())
    eu = eu*1.3043172907561458/np.sqrt((eu**2).sum())

    Rmat_closed = transforms3d.euler.euler2mat(eu[0], eu[1], eu[2])
    Rmat_closed = torch.FloatTensor(Rmat_closed).cuda()
    Rmat_closed = torch.min(Rmat_closed, torch.FloatTensor([0.99999]).cuda())
    Rmat_closed = torch.max(Rmat_closed, torch.FloatTensor([-0.99999]).cuda())

    # Limit was found by:
    # rotm = rodrigues_layer.batch_rodrigues(limit_bigfinger.unsqueeze(0)).view(3, 3)
    # angle = transforms3d.euler.mat2euler(rotm.cpu().data.numpy())
    # 1.3043172907561458 = np.sqrt((np.array(angle)**2).sum())

    #thumb_axisangle_open = rmat_to_axisangle(Rmat_open)
    thumb_axisangle_closed = rmat_to_axisangle(Rmat_closed)

    #handfullpose_open = handfullpose.clone()
    #handfullpose_open[0, 36:39] = torch.FloatTensor([0, 0, 0])
    ##handfullpose_open[0, 36:39] = thumb_axisangle_open #torch.FloatTensor([0, 0, 0])
    #handfullpose_open[0, 0:3] = torch.FloatTensor([0, 0, 0])
    #handfullpose_open[0, 9:12] = torch.FloatTensor([0, 0, 0])
    #handfullpose_open[0, 27:30] = torch.FloatTensor([0, 0, 0])
    #handfullpose_open[0, 18:21] = torch.FloatTensor([0, 0, 0])

    #handfullpose_closed = handfullpose.clone()
    #handfullpose_closed[0, 36:39] = thumb_axisangle_closed #limit_bigfinger_right
    #handfullpose_closed[0, 0:3] = limit_index_right
    #handfullpose_closed[0, 9:12] = limit_middlefinger_right
    #handfullpose_closed[0, 27:30] = limit_fourth_right
    #handfullpose_closed[0, 18:21] = limit_small_right

    #hand_verts_open, hand_joints_open = get_hand(handfullpose_open, R, T)
    #hand_verts_closed, hand_joints_closed = get_hand(handfullpose_closed, R, T)

    #knuckle_joints_open = hand_joints_open[0, [1, 5, 9, 13, 17]]
    #knuckle_joints_closed = hand_joints_closed[0, [1, 5, 9, 13, 17]]

    # OPTIMIZE EACH FINGER INDEPENDENTLY:
    handfullpose_converged = handfullpose.clone()
    touching_indexs = 0

    loss_distance = torch.FloatTensor([0]).cuda()
    loss_angle = torch.FloatTensor([0]).cuda()

    num_samples = 1000//step + 1
    inds = torch.linspace(0, 1, num_samples).cuda().unsqueeze(1)
    handfullpose_repeated = handfullpose_converged.clone().repeat(num_samples, 1)
    handfullpose_repeated[:, 36:39] = thumb_axisangle_closed.unsqueeze(0)*inds
    handfullpose_repeated[:, 0:3] = limit_index_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 9:12] = limit_middlefinger_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 27:30] = limit_fourth_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 18:21] = limit_small_right.unsqueeze(0)*inds

    meshes, _ = get_hand(handfullpose_repeated, R.repeat(num_samples, 1), T.repeat(num_samples, 1))

    relevant_verts_thumb = meshes[:, bigfinger_vertices]
    relevant_verts_index = meshes[:, indexfinger_vertices]
    relevant_verts_middle = meshes[:, middlefinger_vertices]
    relevant_verts_fourth = meshes[:, fourthfinger_vertices]
    relevant_verts_small = meshes[:, smallfinger_vertices]

    # Thumb: 
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_thumb, obj_verts)
    loss_distance = loss_distance + distance_to_minimize.mean()
    if converged:
        handfullpose_converged[0, 36:39] = thumb_axisangle_closed*inds[vertex_solution]
        loss_angle = loss_angle + (handfullpose[0, 36:39] - thumb_axisangle_closed*inds[vertex_solution])**2
        touching_indexs += 1
    else:
        handfullpose_converged[0, 36:39] = thumb_axisangle_closed

    # Index:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_index, obj_verts)
    loss_distance = loss_distance + distance_to_minimize.mean()
    if converged:
        handfullpose_converged[0, 0:3] = limit_index_right*inds[vertex_solution]
        loss_angle = loss_angle + (handfullpose[0, 0:3] - limit_index_right*inds[vertex_solution])**2
        touching_indexs += 1
    else:
        handfullpose_converged[0, 0:3] = limit_index_right

    # Middle:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_middle, obj_verts)
    loss_distance = loss_distance + distance_to_minimize.mean()
    if converged:
        handfullpose_converged[0, 9:12] = limit_middlefinger_right*inds[vertex_solution]
        loss_angle = loss_angle + (handfullpose[0, 9:12] - limit_middlefinger_right*inds[vertex_solution])**2
        touching_indexs += 1
    else:
        handfullpose_converged[0, 9:12] = limit_middlefinger_right

    # Fourth:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_fourth, obj_verts)
    loss_distance = loss_distance + distance_to_minimize.mean()
    if converged:
        handfullpose_converged[0, 27:30] = limit_fourth_right*inds[vertex_solution]
        loss_angle = loss_angle + (handfullpose[0, 27:30] - limit_fourth_right*inds[vertex_solution])**2
        touching_indexs += 1
    else:
        handfullpose_converged[0, 27:30] = limit_fourth_right

    # Small:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_small, obj_verts)
    loss_distance = loss_distance + distance_to_minimize.mean()
    if converged:
        handfullpose_converged[0, 18:21] = limit_small_right*inds[vertex_solution]
        loss_angle = loss_angle + (handfullpose[0, 18:21] - limit_small_right*inds[vertex_solution])**2
        touching_indexs += 1
    else:
        handfullpose_converged[0, 18:21] = limit_small_right

    #####                           #####
    ##### SECOND JOINT OPTIMIZATION #####
    #####                           #####


    loss_angle_secondjoint = torch.FloatTensor([0]).cuda()
    touching_indexs = 0

    num_samples = 1000//step + 1
    inds = torch.linspace(0, 1, num_samples).cuda().unsqueeze(1)
    handfullpose_repeated = handfullpose_converged.clone().repeat(num_samples, 1)
    handfullpose_repeated[:, 39:42] = limit_secondjoint_bigfinger_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 3:6] = limit_secondjoint_index_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 12:15] = limit_secondjoint_middlefinger_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 30:33] = limit_secondjoint_fourth_right.unsqueeze(0)*inds
    handfullpose_repeated[:, 21:24] = limit_secondjoint_small_right.unsqueeze(0)*inds

    meshes, _ = get_hand(handfullpose_repeated, R.repeat(num_samples, 1), T.repeat(num_samples, 1))

    relevant_verts_secondjoint_thumb = meshes[:, bigfinger_secondjoint_vertices]
    relevant_verts_secondjoint_index = meshes[:, indexfinger_secondjoint_vertices]
    relevant_verts_secondjoint_middle = meshes[:, middlefinger_secondjoint_vertices]
    relevant_verts_secondjoint_fourth = meshes[:, fourthfinger_secondjoint_vertices]
    relevant_verts_secondjoint_small = meshes[:, smallfinger_secondjoint_vertices]

    # Thumb: 
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_secondjoint_thumb, obj_verts)
    if converged:
        handfullpose_converged[0, 39:42] = limit_secondjoint_bigfinger_right*inds[vertex_solution]
        loss_angle_secondjoint = loss_angle_secondjoint + (handfullpose[0, 39:42] - limit_secondjoint_bigfinger_right*inds[vertex_solution])**2
        touching_indexs += 1

    # Index:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_secondjoint_index, obj_verts)
    if converged:
        handfullpose_converged[0, 3:6] = limit_secondjoint_index_right*inds[vertex_solution]
        loss_angle_secondjoint = loss_angle_secondjoint + (handfullpose[0, 3:6] - limit_secondjoint_index_right*inds[vertex_solution])**2
        touching_indexs += 1

    # Middle:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_secondjoint_middle, obj_verts)
    if converged:
        handfullpose_converged[0, 12:15] = limit_secondjoint_middlefinger_right*inds[vertex_solution]
        loss_angle_secondjoint = loss_angle_secondjoint + (handfullpose[0, 12:15] - limit_secondjoint_middlefinger_right*inds[vertex_solution])**2
        touching_indexs += 1

    # Fourth:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_secondjoint_fourth, obj_verts)
    if converged:
        handfullpose_converged[0, 30:33] = limit_secondjoint_fourth_right*inds[vertex_solution]
        loss_angle_secondjoint = loss_angle_secondjoint + (handfullpose[0, 30:33] - limit_secondjoint_fourth_right*inds[vertex_solution])**2
        touching_indexs += 1

    # Small:
    distance_to_minimize, vertex_solution, converged = get_optimization_angle(relevant_verts_secondjoint_small, obj_verts)
    if converged:
        handfullpose_converged[0, 21:24] = limit_secondjoint_small_right*inds[vertex_solution]
        loss_angle_secondjoint = loss_angle_secondjoint + (handfullpose[0, 21:24] - limit_secondjoint_small_right*inds[vertex_solution])**2
        touching_indexs += 1

    return handfullpose_converged, touching_indexs, loss_distance, loss_angle.mean(), loss_angle_secondjoint.mean()
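The five per-finger blocks in each pass above repeat one pattern; a hedged refactor sketch of the first pass (helper name hypothetical, loss bookkeeping omitted):

# def close_finger(pose_converged, sl, limit, relevant_verts, obj_verts, inds):
#     dist, sol, converged = get_optimization_angle(relevant_verts, obj_verts)
#     pose_converged[0, sl] = limit * inds[sol] if converged else limit
#     return dist.mean(), converged
#
# e.g. close_finger(handfullpose_converged, slice(36, 39),
#                   thumb_axisangle_closed, relevant_verts_thumb, obj_verts, inds)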
Example #9
def get_hand(th_full_hand_pose, R, th_trans):
    batch_size = len(th_trans)
    th_full_pose = torch.cat([
        R,
        MANO.th_hands_mean + th_full_hand_pose
    ], 1)
    th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
    th_full_pose = th_full_pose.view(batch_size, -1, 3)
    root_rot = rodrigues_layer.batch_rodrigues(th_full_pose[:, 0]).view(batch_size, 3, 3)

    # NOTE: WITH DEFAULT HAND SHAPE PARAMETERS:
    # th_v_shape is [batchsize, 778, 3] -> For baseline hand position
    th_v_shaped = torch.matmul(MANO.th_shapedirs,
                               MANO.th_betas.transpose(1, 0)).permute(
                                   2, 0, 1) + MANO.th_v_template
    th_j = torch.matmul(MANO.th_J_regressor, th_v_shaped).repeat(
        batch_size, 1, 1)

    # NOTE: GET HAND MESH VERTICES: 778 vertices in 3D
    # th_v_posed -> [batchsize, 778, 3]
    th_v_posed = th_v_shaped + torch.matmul(
        MANO.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
    # Final T pose with transformation done !

    # Global rigid transformation
    th_results = []

    root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
    th_results.append(th_with_zeros(torch.cat([root_rot, root_j], 2)))

    # Rotate each part
    for i in range(15):
        i_val = int(i + 1)
        joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val * 9]
        joint_rot = joint_rot.contiguous().view(batch_size, 3, 3)
        joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
        parent = make_list(MANO.kintree_parents)[i_val]
        parent_j = th_j[:, parent, :].contiguous().view(batch_size, 3, 1)
        joint_rel_transform = th_with_zeros(
            torch.cat([joint_rot, joint_j - parent_j], 2))
        th_results.append(
            torch.matmul(th_results[parent], joint_rel_transform))
    th_results_global = th_results

    th_results2 = torch.zeros((batch_size, 4, 4, 16),
                              dtype=root_j.dtype,
                              device=root_j.device)
    for i in range(16):
        padd_zero = torch.zeros(1, dtype=th_j.dtype, device=th_j.device)
        joint_j = torch.cat(
            [th_j[:, i],
             padd_zero.view(1, 1).repeat(batch_size, 1)], 1)
        tmp = torch.bmm(th_results[i], joint_j.unsqueeze(2))
        th_results2[:, :, :, i] = th_results[i] - th_pack(tmp)

    th_T = torch.matmul(th_results2, MANO.th_weights.transpose(0, 1))

    th_rest_shape_h = torch.cat([
        th_v_posed.transpose(2, 1),
        torch.ones((batch_size, 1, th_v_posed.shape[1]),
                   dtype=th_T.dtype,
                   device=th_T.device),
    ], 1)

    th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
    th_verts = th_verts[:, :, :3]
    th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]
    # In addition to MANO reference joints we sample vertices on each finger
    # to serve as finger tips
    if MANO.side == 'right':
        tips = torch.stack([
            th_verts[:, 745], th_verts[:, 317], th_verts[:, 444],
            th_verts[:, 556], th_verts[:, 673]
        ],
                           dim=1)
    else:
        tips = torch.stack([
            th_verts[:, 745], th_verts[:, 317], th_verts[:, 445],
            th_verts[:, 556], th_verts[:, 673]
        ],
                           dim=1)
    th_jtr = torch.cat([th_jtr, tips], 1)

    # Reorder joints to match visualization utilities
    th_jtr = torch.stack([
        th_jtr[:, 0], th_jtr[:, 13], th_jtr[:, 14], th_jtr[:, 15],
        th_jtr[:, 16], th_jtr[:, 1], th_jtr[:, 2], th_jtr[:, 3],
        th_jtr[:, 17], th_jtr[:, 4], th_jtr[:, 5], th_jtr[:, 6],
        th_jtr[:, 18], th_jtr[:, 10], th_jtr[:, 11], th_jtr[:, 12],
        th_jtr[:, 19], th_jtr[:, 7], th_jtr[:, 8], th_jtr[:, 9],
        th_jtr[:, 20]
    ],
                         dim=1)

    if th_trans is None or bool(torch.norm(th_trans) == 0):
        if MANO.center_idx is not None:
            center_joint = th_jtr[:, MANO.center_idx].unsqueeze(1)
            th_jtr = th_jtr - center_joint
            th_verts = th_verts - center_joint
    else:
        th_jtr = th_jtr + th_trans.unsqueeze(1)
        th_verts = th_verts + th_trans.unsqueeze(1)

    return th_verts, th_jtr
    def forward(self,
                th_pose_coeffs,
                th_betas=torch.zeros(1),
                th_trans=torch.zeros(1),
                root_palm=torch.Tensor([0])):
        """
        Args:
        th_trans (Tensor (batch_size x 3)): if provided and non-zero, applied as
            a translation to joints and vertices; otherwise, if center_idx is
            set, the output is centered on that joint
        th_betas (Tensor (batch_size x 10)): if provided, uses the given shape
            parameters for the hand shape
        root_palm: return palm as hand root instead of wrist
        """
        # if len(th_pose_coeffs) == 0:
        #     return th_pose_coeffs.new_empty(0), th_pose_coeffs.new_empty(0)

        batch_size = th_pose_coeffs.shape[0]
        # Get axis angle from PCA components and coefficients
        if self.use_pca:
            th_hand_pose_coeffs = th_pose_coeffs[:, self.rot:self.rot +
                                                 self.ncomps]
            th_full_hand_pose = th_hand_pose_coeffs.mm(self.th_selected_comps)
            th_full_pose = torch.cat([
                th_pose_coeffs[:, :self.rot],
                self.th_hands_mean + th_full_hand_pose
            ], 1)
            th_pose_map, th_rot_map = th_posemap_axisang(th_full_pose)
            th_full_pose = th_full_pose.view(batch_size, -1, 3)
            root_rot = rodrigues_layer.batch_rodrigues(
                th_full_pose[:, 0]).view(batch_size, 3, 3)
        else:
            assert th_pose_coeffs.dim() == 4, (
                'When not self.use_pca, '
                'th_pose_coeffs should have 4 dims, got {}'.format(
                    th_pose_coeffs.dim()))
            assert th_pose_coeffs.shape[2:4] == (3, 3), (
                'When not self.use_pca, th_pose_coeffs should have a 3x3 matrix '
                'for the two last dims, got {}'.format(th_pose_coeffs.shape[2:4]))
            th_pose_rots = rotproj.batch_rotprojs(th_pose_coeffs)
            th_rot_map = th_pose_rots[:, 1:].view(batch_size, -1)
            th_pose_map = subtract_flat_id(th_rot_map)
            root_rot = th_pose_rots[:, 0]

        # Full axis angle representation with root joint
        if th_betas is None or bool(torch.norm(th_betas) == 0):
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       self.th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor,
                                th_v_shaped).repeat(batch_size, 1, 1)

        else:
            th_v_shaped = torch.matmul(self.th_shapedirs,
                                       th_betas.transpose(1, 0)).permute(
                                           2, 0, 1) + self.th_v_template
            th_j = torch.matmul(self.th_J_regressor, th_v_shaped)
            # th_pose_map should have shape 20x135

        th_v_posed = th_v_shaped + torch.matmul(
            self.th_posedirs, th_pose_map.transpose(0, 1)).permute(2, 0, 1)
        # Final T pose with transformation done !

        # Global rigid transformation
        th_results = []

        root_j = th_j[:, 0, :].contiguous().view(batch_size, 3, 1)
        th_results.append(th_with_zeros(torch.cat([root_rot, root_j], 2)))

        # Rotate each part
        for i in range(15):
            i_val = int(i + 1)
            joint_rot = th_rot_map[:, (i_val - 1) * 9:i_val * 9]
            joint_rot = joint_rot.contiguous().view(batch_size, 3, 3)
            joint_j = th_j[:, i_val, :].contiguous().view(batch_size, 3, 1)
            parent = make_list(self.kintree_parents)[i_val]
            parent_j = th_j[:, parent, :].contiguous().view(batch_size, 3, 1)
            joint_rel_transform = th_with_zeros(
                torch.cat([joint_rot, joint_j - parent_j], 2))
            th_results.append(
                torch.matmul(th_results[parent], joint_rel_transform))
        th_results_global = th_results

        th_results2 = torch.zeros((batch_size, 4, 4, 16),
                                  dtype=root_j.dtype,
                                  device=root_j.device)

        for i in range(16):
            padd_zero = torch.zeros(1, dtype=th_j.dtype, device=th_j.device)
            joint_j = torch.cat(
                [th_j[:, i],
                 padd_zero.view(1, 1).repeat(batch_size, 1)], 1)
            tmp = torch.bmm(th_results[i], joint_j.unsqueeze(2))
            th_results2[:, :, :, i] = th_results[i] - th_pack(tmp)

        th_T = torch.matmul(th_results2, self.th_weights.transpose(0, 1))

        th_rest_shape_h = torch.cat([
            th_v_posed.transpose(2, 1),
            torch.ones((batch_size, 1, th_v_posed.shape[1]),
                       dtype=th_T.dtype,
                       device=th_T.device),
        ], 1)

        th_verts = (th_T * th_rest_shape_h.unsqueeze(1)).sum(2).transpose(2, 1)
        th_verts = th_verts[:, :, :3]
        th_jtr = torch.stack(th_results_global, dim=1)[:, :, :3, 3]
        # In addition to MANO reference joints we sample vertices on each finger
        # to serve as finger tips
        if self.side == 'right':
            tips = torch.stack([
                th_verts[:, 745], th_verts[:, 317], th_verts[:, 444],
                th_verts[:, 556], th_verts[:, 673]
            ],
                               dim=1)
        else:
            tips = torch.stack([
                th_verts[:, 745], th_verts[:, 317], th_verts[:, 445],
                th_verts[:, 556], th_verts[:, 673]
            ],
                               dim=1)
        if bool(root_palm):
            palm = (th_verts[:, 95] + th_verts[:, 22]).unsqueeze(1) / 2
            th_jtr = torch.cat([palm, th_jtr[:, 1:]], 1)
        th_jtr = torch.cat([th_jtr, tips], 1)

        # Reorder joints to match visualization utilities
        th_jtr = torch.stack([
            th_jtr[:, 0], th_jtr[:, 13], th_jtr[:, 14], th_jtr[:, 15],
            th_jtr[:, 16], th_jtr[:, 1], th_jtr[:, 2], th_jtr[:, 3],
            th_jtr[:, 17], th_jtr[:, 4], th_jtr[:, 5], th_jtr[:, 6],
            th_jtr[:, 18], th_jtr[:, 10], th_jtr[:, 11], th_jtr[:, 12],
            th_jtr[:, 19], th_jtr[:, 7], th_jtr[:, 8], th_jtr[:, 9],
            th_jtr[:, 20]
        ],
                             dim=1)

        if th_trans is None or bool(torch.norm(th_trans) == 0):
            if self.center_idx is not None:
                center_joint = th_jtr[:, self.center_idx].unsqueeze(1)
                th_jtr = th_jtr - center_joint
                th_verts = th_verts - center_joint
        else:
            th_jtr = th_jtr + th_trans.unsqueeze(1)
            th_verts = th_verts + th_trans.unsqueeze(1)

        # Scale to millimeters
        th_verts = th_verts * 1000
        th_jtr = th_jtr * 1000
        return th_verts, th_jtr
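
    # A minimal usage sketch (hedged: the class name ManoLayer and the
    # mano_root path are assumptions):
    #   mano = ManoLayer(ncomps=6, mano_root="mano/models", side="right")
    #   pose = torch.zeros(1, mano.rot + mano.ncomps)  # global rot + PCA coeffs
    #   verts, joints = mano(pose)  # (1, 778, 3) and (1, 21, 3), in millimeters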
    def __init__(self,
                 center_idx=None,
                 flat_hand_mean=True,
                 ncomps=6,
                 side='right',
                 mano_root='mano/models',
                 use_pca=True):
        """
        Args:
            center_idx: index of center joint in our computations,
                if -1 centers on estimate of palm as middle of base
                of middle finger and wrist
            flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
                flat hand, else match average hand pose
            mano_root: path to MANO pkl files for left and right hand
            ncomps: number of PCA components from pose space (<45)
            side: 'right' or 'left'
            use_pca: Use PCA decomposition for pose space.
        """
        super().__init__()

        self.center_idx = center_idx
        self.rot = 3
        self.flat_hand_mean = flat_hand_mean
        self.side = side
        self.use_pca = use_pca
        if use_pca:
            self.ncomps = ncomps
        else:
            self.ncomps = 45

        if side == 'right':
            self.mano_path = os.path.join(mano_root, 'MANO_RIGHT.pkl')
        elif side == 'left':
            self.mano_path = os.path.join(mano_root, 'MANO_LEFT.pkl')

        smpl_data = ready_arguments(self.mano_path)

        hands_components = smpl_data['hands_components']

        self.smpl_data = smpl_data

        self.register_buffer('th_betas',
                             torch.Tensor(smpl_data['betas'].r).unsqueeze(0))
        self.register_buffer('th_shapedirs',
                             torch.Tensor(smpl_data['shapedirs'].r))
        self.register_buffer('th_posedirs',
                             torch.Tensor(smpl_data['posedirs'].r))
        self.register_buffer(
            'th_v_template',
            torch.Tensor(smpl_data['v_template'].r).unsqueeze(0))
        self.register_buffer(
            'th_J_regressor',
            torch.Tensor(np.array(smpl_data['J_regressor'].toarray())))
        self.register_buffer('th_weights',
                             torch.Tensor(smpl_data['weights'].r))
        self.register_buffer(
            'th_faces',
            torch.Tensor(smpl_data['f'].astype(np.int32)).long())

        # Get hand mean
        hands_mean = np.zeros(hands_components.shape[1]
                              ) if flat_hand_mean else smpl_data['hands_mean']
        hands_mean = hands_mean.copy()
        th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
        if self.use_pca:
            # Save as axis-angle
            self.register_buffer('th_hands_mean', th_hands_mean)
            selected_components = hands_components[:ncomps]
            self.register_buffer('th_selected_comps',
                                 torch.Tensor(selected_components))
        else:
            th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
                th_hands_mean.view(15, 3)).reshape(15, 3, 3)
            self.register_buffer('th_hands_mean_rotmat', th_hands_mean_rotmat)

        # Kinematic chain params
        self.kintree_table = smpl_data['kintree_table']
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents
Example #12
    def __init__(
        self,
        center_idx=None,
        flat_hand_mean=True,
        ncomps=6,
        side="right",
        mano_root="mano/models",
        use_pca=True,
        root_rot_mode="axisang",
        joint_rot_mode="axisang",
        robust_rot=False,
        return_transf=False,
        return_full_pose=False,
    ):
        """
        Args:
            center_idx: index of center joint in our computations,
                if -1 centers on estimate of palm as middle of base
                of middle finger and wrist
            flat_hand_mean: if True, (0, 0, 0, ...) pose coefficients match
                flat hand, else match average hand pose
            mano_root: path to MANO pkl files for left and right hand
            ncomps: number of PCA components from pose space (<45)
            side: 'right' or 'left'
            use_pca: Use PCA decomposition for pose space.
            root_rot_mode: 'axisang' or 'rotmat' or 'quat',
            joint_rot_mode: 'axisang' or 'rotmat' or 'quat', ignored if use_pca
        """
        super().__init__()

        self.center_idx = center_idx
        self.robust_rot = robust_rot

        # check root_rot_mode feasible, and set self.rot
        if root_rot_mode == "axisang":
            self.rot = 3
        elif root_rot_mode == "rotmat":
            self.rot = 6
        elif root_rot_mode == "quat":
            self.rot = 4
        else:
            raise KeyError(
                "root_rot_mode not found. shoule be one of 'axisang' or 'rotmat' or 'quat'. got {}"
                .format(root_rot_mode))

        # TODO: flat_hand_mean has issues
        self.flat_hand_mean = flat_hand_mean

        # toggle extra return information
        self.return_transf = return_transf
        self.return_full_pose = return_full_pose

        # record side of hands
        self.side = side

        # check use_pca and joint_rot_mode
        if use_pca and joint_rot_mode != "axisang":
            raise TypeError(
                "if use_pca, joint_rot_mode must be 'axisang'. got {}".format(
                    joint_rot_mode))
        # record use_pca flag and joint_rot_mode
        self.use_pca = use_pca
        self.joint_rot_mode = joint_rot_mode
        # self.ncomps only work in axisang mode
        if use_pca:
            self.ncomps = ncomps
        else:
            self.ncomps = 45

        # do more checks on root_rot_mode, in case mode error
        if self.joint_rot_mode == "axisang":
            # add restriction to root_rot_mode
            if root_rot_mode not in ["axisang", "rotmat"]:
                err_msg = "rot_mode not compatible, "
                err_msg += "when joint_rot_mode is 'axisang', root_rot_mode should be one of "
                err_msg += "'axisang' or 'rotmat', got {}".format(
                    root_rot_mode)
                raise KeyError(err_msg)
        else:
            # for 'rotmat' or 'quat', the rot_mode must be the same for joint and root
            if root_rot_mode != self.joint_rot_mode:
                err_msg = "rot_mode not compatible, "
                err_msg += "should get the same rot mode for joint and root, "
                err_msg += "got {} for root and {} for joint".format(
                    root_rot_mode, self.joint_rot_mode)
                raise KeyError(err_msg)
        # record root_rot_mode
        self.root_rot_mode = root_rot_mode

        # load model according to side flag
        if side == "right":
            self.mano_path = os.path.join(mano_root, "MANO_RIGHT.pkl")
        elif side == "left":
            self.mano_path = os.path.join(mano_root, "MANO_LEFT.pkl")

        # parse and register stuff
        smpl_data = ready_arguments(self.mano_path)

        hands_components = smpl_data["hands_components"]

        self.smpl_data = smpl_data

        self.register_buffer(
            "th_betas",
            torch.Tensor(np.array(smpl_data["betas"].r)).unsqueeze(0))
        self.register_buffer("th_shapedirs",
                             torch.Tensor(np.array(smpl_data["shapedirs"].r)))
        self.register_buffer("th_posedirs",
                             torch.Tensor(np.array(smpl_data["posedirs"].r)))
        self.register_buffer(
            "th_v_template",
            torch.Tensor(np.array(smpl_data["v_template"].r)).unsqueeze(0))
        self.register_buffer(
            "th_J_regressor",
            torch.Tensor(np.array(smpl_data["J_regressor"].toarray())))
        self.register_buffer("th_weights",
                             torch.Tensor(np.array(smpl_data["weights"].r)))
        self.register_buffer(
            "th_faces",
            torch.Tensor(np.array(smpl_data["f"]).astype(np.int32)).long())

        # Get hand mean
        hands_mean = np.zeros(hands_components.shape[1]
                              ) if flat_hand_mean else smpl_data["hands_mean"]
        hands_mean = hands_mean.copy()
        th_hands_mean = torch.Tensor(hands_mean).unsqueeze(0)
        if self.use_pca or self.joint_rot_mode == "axisang":
            # Save as axis-angle
            self.register_buffer("th_hands_mean", th_hands_mean)
            selected_components = hands_components[:ncomps]
            self.register_buffer("th_selected_comps",
                                 torch.Tensor(selected_components))
        elif self.joint_rot_mode == "rotmat":
            th_hands_mean_rotmat = rodrigues_layer.batch_rodrigues(
                th_hands_mean.view(15, 3)).reshape(15, 3, 3)
            self.register_buffer("th_hands_mean_rotmat", th_hands_mean_rotmat)
        elif self.joint_rot_mode == "quat":
            # TODO deal with flat hand mean
            self.register_buffer("th_hands_mean_quat", None)
        else:
            raise KeyError(
                "joint_rot_mode not found. shoule be one of 'axisang' or 'rotmat' or 'quat'. got {}"
                .format(self.joint_rot_mode))

        # Kinematic chain params
        self.kintree_table = smpl_data["kintree_table"]
        parents = list(self.kintree_table[0].tolist())
        self.kintree_parents = parents
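A minimal construction sketch for the rotation-matrix modes (hedged: the class name and path are assumptions):

# layer = ManoLayer(use_pca=False, joint_rot_mode="rotmat",
#                   root_rot_mode="rotmat", mano_root="mano/models")
# With these modes self.rot == 6, and the mean pose is stored as
# th_hands_mean_rotmat, a (15, 3, 3) stack of rotation matrices.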