def forward(self, *args, **kwargs):
    """Run the parent SMPL forward pass and append two regressed joint sets.

    The parent forward is always invoked with ``get_skin=True``. Two extra
    joint sets are regressed from the posed vertices with ``vertices2joints``
    and concatenated along dim 1. The order is: parent joints, then joints
    from ``self.J_regressor_extra``, then joints from ``self.J_regressor_h36m``.

    Returns:
        ModelOutput carrying the parent's vertices, global_orient, body_pose,
        betas and full_pose, with ``joints`` set to the concatenation.
    """
    kwargs['get_skin'] = True
    base = super(SMPL, self).forward(*args, **kwargs)
    verts = base.vertices

    # Parent joints first, then the two regressed sets, in this fixed order.
    joint_sets = [
        base.joints,
        vertices2joints(self.J_regressor_extra, verts),
        vertices2joints(self.J_regressor_h36m, verts),
    ]

    return ModelOutput(
        vertices=verts,
        global_orient=base.global_orient,
        body_pose=base.body_pose,
        joints=torch.cat(joint_sets, dim=1),
        betas=base.betas,
        full_pose=base.full_pose,
    )
def get_default_pose(v_template, betas, shapedirs, J_regressor):
    """Regress joint locations for the shape-deformed template mesh.

    Adds the shape blend-shape offsets ``blend_shapes(betas, shapedirs)`` to
    ``v_template``. It then regresses joints from those vertices with
    ``J_regressor`` and returns the first batch entry passed through ``c2c``.
    NOTE(review): ``c2c`` is defined elsewhere; presumably a tensor-to-numpy
    conversion — confirm.
    """
    from smplx.lbs import vertices2joints, blend_shapes

    shape_offsets = blend_shapes(betas, shapedirs)
    shaped_template = v_template + shape_offsets

    # Assumed N x J x 3 (batch, joints, xyz) — TODO confirm against callers.
    regressed = vertices2joints(J_regressor, shaped_template)
    first_item = regressed[0]
    return c2c(first_item)
Example #3
0
    def forward(self, *args, **kwargs):
        """Run SMPL-X and re-express its joints in SMPL and Panoptic-hand order.

        An SMPL-style ``body_pose`` is accepted. It may have 23 joints, either
        flattened to 69 values or with 23 entries along dim 1. Its last two
        joints (palm joints, not used here) are dropped before the parent
        SMPL-X forward is called. If ``body_pose`` is omitted or None, it is
        passed through to the parent unchanged instead of raising.

        Returns:
            ModelOutput with:
              * ``joints``: 45 SMPL-compatible joints selected from the SMPL-X
                joints, followed by the ``self.J_regressor_extra`` joints, then
                reordered/selected by ``self.joint_map``.
              * ``left_hand_joints`` / ``right_hand_joints``: 21 joints per
                hand (wrist + 20 finger joints) in Panoptic order
                (wrist, thumb ... pinky).
              * vertices, global_orient, body_pose, betas, full_pose copied
                from the parent output.
        """
        kwargs['get_skin'] = True

        # Trim SMPL's 23-joint body pose to the 21 joints SMPL-X uses.
        # Guarded so that a missing or None body_pose no longer crashes here.
        body_pose = kwargs.get('body_pose')
        if body_pose is not None:
            if body_pose.shape[1] == 69:  # flattened: 23 joints * 3 values
                kwargs['body_pose'] = body_pose[:, :-2 * 3]
            elif body_pose.shape[1] == 23:  # 23 joints along dim 1
                kwargs['body_pose'] = body_pose[:, :-2]

        smpl_output = super(SMPLX, self).forward(*args, **kwargs)
        extra_joints = vertices2joints(self.J_regressor_extra,
                                       smpl_output.vertices)
        # TODO: J_regressor_extra may be meant to apply only to the first 6890
        # (SMPL-topology) vertices, i.e. smpl_output.vertices[:, :6890] —
        # verify before relying on extra_joints for SMPL-X meshes.

        # SMPL-X joint order:
        # https://docs.google.com/spreadsheets/d/1_1dLdaX-sbMkCKr_JzJW_RZCpwBwd7rcKkWT_VgAQ_0/edit#gid=0
        # Selects 45 joints: body joints 0..21, joint 28 (left middle
        # finger 1) and 43 (right middle finger 1) standing in for SMPL's
        # hand joints, then the 21 joints at indices 55..75.
        smplx_to_smpl = list(range(0, 22)) + [28, 43] + list(range(55, 76))
        smpl_joints = smpl_output.joints[:, smplx_to_smpl, :]
        joints = torch.cat([smpl_joints, extra_joints], dim=1)
        joints = joints[:, self.joint_map, :]

        # Reorders a 21-joint SMPL-X hand (wrist at index 0) into Panoptic
        # hand order: wrist, thumb, index, middle, ring, pinky.
        smplx_hand_to_panoptic = [
            0, 13, 14, 15, 16, 1, 2, 3, 17, 4, 5, 6, 18, 10, 11, 12, 19, 7, 8,
            9, 20
        ]

        # Joint 20 is the left wrist, followed by 20 left-hand finger joints.
        smplx_lhand = [20] + list(range(25, 40)) + list(range(66, 71))
        lhand_joints = smpl_output.joints[:, smplx_lhand, :]
        lhand_joints = lhand_joints[:, smplx_hand_to_panoptic, :]

        # Joint 21 is the right wrist, followed by 20 right-hand finger joints.
        smplx_rhand = [21] + list(range(40, 55)) + list(range(71, 76))
        rhand_joints = smpl_output.joints[:, smplx_rhand, :]
        rhand_joints = rhand_joints[:, smplx_hand_to_panoptic, :]

        output = ModelOutput(
            vertices=smpl_output.vertices,
            global_orient=smpl_output.global_orient,
            body_pose=smpl_output.body_pose,
            joints=joints,
            right_hand_joints=rhand_joints,  # 21 joints, Panoptic order
            left_hand_joints=lhand_joints,  # 21 joints, Panoptic order
            betas=smpl_output.betas,
            full_pose=smpl_output.full_pose)
        return output
Example #4
0
 def forward(self, *args, **kwargs):
     """Run the parent SMPL forward and replace joints with cocoplus joints.

     Joints are regressed from the posed vertices with
     ``self.J_regressor_cocoplus``, and only the first 14 are kept. All other
     fields are copied from the parent output. The parent forward is always
     called with ``get_skin=True``.
     """
     kwargs['get_skin'] = True
     base = super(SMPL, self).forward(*args, **kwargs)
     all_coco = vertices2joints(self.J_regressor_cocoplus, base.vertices)
     kept = all_coco[:, :14]
     return ModelOutput(
         vertices=base.vertices,
         global_orient=base.global_orient,
         body_pose=base.body_pose,
         joints=kept,
         betas=base.betas,
         full_pose=base.full_pose,
     )
Example #5
0
 def forward(self, *args, **kwargs):
     """Run the parent SMPL forward and append extra regressed joints.

     Joints from ``self.J_regressor_extra`` are concatenated after the parent's
     joints along dim 1. The result is then indexed with ``self.joint_map``.
     The parent forward is always called with ``get_skin=True``.
     """
     kwargs['get_skin'] = True
     out = super(SMPL, self).forward(*args, **kwargs)
     regressed = vertices2joints(self.J_regressor_extra, out.vertices)
     merged = torch.cat([out.joints, regressed], dim=1)[:, self.joint_map, :]
     fields = dict(
         vertices=out.vertices,            # mesh vertices
         global_orient=out.global_orient,  # global orientation
         body_pose=out.body_pose,          # body pose
         joints=merged,                    # joints
         betas=out.betas,                  # shape parameters
         full_pose=out.full_pose,          # full pose
     )
     return ModelOutput(**fields)
Example #6
0
 def forward(self, *args, **kwargs):
     """Run the parent SMPL forward, add extra joints, then apply joint_map.

     The parent forward is always called with ``get_skin=True``. Extra joints
     are regressed from the posed vertices with ``self.J_regressor_extra``.
     See doc/J_regressor_extra.png, which presumably describes 9 extra
     joints — confirm.
     """
     kwargs['get_skin'] = True
     parent_out = super(SMPL, self).forward(*args, **kwargs)
     mesh = parent_out.vertices
     added = vertices2joints(self.J_regressor_extra, mesh)
     # Expected layout: [N, 24 + 21, 3] parent joints + [N, 9, 3] extra — verify.
     stacked = torch.cat((parent_out.joints, added), dim=1)
     selected = stacked[:, self.joint_map, :]
     return ModelOutput(
         vertices=mesh,
         global_orient=parent_out.global_orient,
         body_pose=parent_out.body_pose,
         joints=selected,
         betas=parent_out.betas,
         full_pose=parent_out.full_pose,
     )
Example #7
0
 def forward(self, *args, **kwargs):
     """Run the parent forward and expose mapped, J19 and raw SMPL joints.

     Returns ``self.ModelOutput`` with:
       * ``joints``: parent joints + ``self.J_regressor_extra`` joints,
         indexed by ``self.joint_map``;
       * ``joints_J19``: the last 24 mapped joints re-indexed by
         ``constants.J24_TO_J19``;
       * ``smpl_joints``: the first 24 parent joints.
     The parent forward is always called with ``get_skin=True``.
     """
     kwargs['get_skin'] = True
     base = super().forward(*args, **kwargs)
     verts = base.vertices
     extra = vertices2joints(self.J_regressor_extra, verts)
     # Expected: base.joints [B, 45, 3], extra [B, 9, 3] -> mapped [B, 49, 3] (verify).
     first_24 = base.joints[:, :24]
     mapped = torch.cat([base.joints, extra], dim=1)[:, self.joint_map, :]
     tail_24 = mapped[:, -24:, :]
     j19 = tail_24[:, constants.J24_TO_J19, :]
     return self.ModelOutput(
         vertices=verts,
         global_orient=base.global_orient,
         body_pose=base.body_pose,
         joints=mapped,
         joints_J19=j19,
         smpl_joints=first_24,
         betas=base.betas,
         full_pose=base.full_pose,
     )
Example #8
0
    def get_all_from_pose(self,
                          poses,
                          regress=True,
                          th_betas=torch.zeros(1),
                          th_trans=torch.zeros(1)):
        """Compute vertices, mapped joints and optionally regressed joints.

        Args:
            poses: pose parameters forwarded to ``self.get_vert_from_pose``.
            regress: when True, also regress joints from the vertices with
                ``self.J_regressor``.
            th_betas: shape parameters forwarded to ``get_vert_from_pose``.
            th_trans: translation forwarded to ``get_vert_from_pose``.
                NOTE: both tensor defaults are created once at definition
                time and shared across calls; they must not be mutated
                in place.

        Returns:
            ``(verts, joints, joints_36m)``, each moved to ``self.device``.
            ``joints`` is the concatenation of the selected joints and the
            ``self.J_regressor_extra`` joints, indexed by ``self.joint_map``.
            ``joints_36m`` is ``None`` when ``regress`` is False. Previously
            that path raised UnboundLocalError.
        """
        verts, Jtr = self.get_vert_from_pose(poses,
                                             th_betas=th_betas,
                                             th_trans=th_trans)
        Jtr = self.vertex_joint_selector(verts, Jtr)

        extra_joints = vertices2joints(self.J_regressor_extra,
                                       verts).to(self.device)
        joints = torch.cat([Jtr, extra_joints], dim=1).to(self.device)
        joints = joints[:, self.joint_map, :]

        joints_36m = None
        if regress:
            # Broadcast the 2-D regressor over the batch so it can be
            # batch-matmul'd against verts (assumed (B, V, 3) — confirm).
            J_regressor_batch = self.J_regressor[None, :].expand(
                verts.shape[0], -1, -1).to(self.device)
            joints_36m = torch.matmul(J_regressor_batch,
                                      verts).to(self.device)

        # Comment from the original code: 49 mapped joints and 17 H36M joints
        # (depends on joint_map / J_regressor).
        return verts.to(self.device), joints.to(self.device), joints_36m