コード例 #1
0
ファイル: flame.py プロジェクト: lzhbrian/FaceGeometryTorch
    def get_vertices_and_3D_landmarks(self):
        """Evaluate the model from the parameters stored on the module.

        Shape/expression coefficients and the pose vectors are read from
        ``self``, passed through ``lbs`` to obtain the posed mesh, and the 3D
        landmarks are interpolated on that mesh with ``vertices2landmarks``.
        ``self.transl`` is added to both outputs.

        Returns:
            (vertices, landmarks_3d); both tensors have their singleton
            dimensions squeezed away in place before being returned.
        """
        # A fixed shape, when present, overrides the stored shape parameters.
        if self.fixed_shape is not None:
            shape_coeffs = self.fixed_shape
        else:
            shape_coeffs = self.shape_params
        betas = torch.cat([shape_coeffs, self.expression_params], dim=1)

        # First three columns: global rotation; remaining columns: jaw rotation.
        rot_and_jaw = torch.cat([self.global_rot, self.jaw_pose], dim=1)
        global_part = rot_and_jaw[:, :3]
        jaw_part = rot_and_jaw[:, 3:]
        full_pose = torch.cat(
            [global_part, self.neck_pose, jaw_part, self.eye_pose], dim=1)

        # One copy of the template mesh per batch element.
        batched_template = self.v_template.unsqueeze(0).repeat(
            self.batch_size, 1, 1)
        vertices, _ = lbs(betas, full_pose, batched_template, self.shapedirs,
                          self.posedirs, self.J_regressor, self.parents,
                          self.lbs_weights, dtype=self.dtype)

        face_ids = self.lmk_faces_idx.unsqueeze(dim=0)
        bary = self.lmk_bary_coords.unsqueeze(dim=0)
        if self.use_face_contour:
            # Pose-dependent contour landmarks are placed before the static ones.
            contour_ids, contour_bary = self._find_dynamic_lmk_idx_and_bcoords(
                vertices, full_pose, self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords, self.neck_kin_chain,
                dtype=self.dtype)
            face_ids = torch.cat([contour_ids, face_ids], 1)
            bary = torch.cat([contour_bary, bary], 1)

        landmarks_3d = vertices2landmarks(vertices, self.faces_tensor,
                                          face_ids, bary)

        # Translate in place, then drop every size-1 dimension in place.
        offset = self.transl.unsqueeze(dim=1)
        landmarks_3d += offset
        vertices += offset

        landmarks_3d.squeeze_()
        vertices.squeeze_()
        return vertices, landmarks_3d
コード例 #2
0
    def forward(self, shape_params=None, expression_params=None, pose_params=None, neck_pose=None, eye_pose=None, transl=None):
        """Compute posed vertices and 3D landmarks.

            Input:
                shape_params: N X number of shape parameters
                expression_params: N X number of expression parameters
                pose_params: N X number of pose parameters
                neck_pose, eye_pose, transl: optional overrides; the values
                    stored on the module are used when they are None
            return:
                vertices: N X V X 3
                landmarks: N X number of landmarks X 3

            Despite their None defaults, shape_params, expression_params and
            pose_params are used unconditionally and must be supplied.
        """
        # Caller coefficients are padded with the module's extra betas.
        betas = torch.cat(
            [shape_params, self.shape_betas,
             expression_params, self.expression_betas],
            dim=1)

        if neck_pose is None:
            neck_pose = self.neck_pose
        if eye_pose is None:
            eye_pose = self.eye_pose
        if transl is None:
            transl = self.transl

        # Pose layout: global rotation, neck, remaining pose columns, eyes.
        global_rot = pose_params[:, :3]
        remaining_pose = pose_params[:, 3:]
        full_pose = torch.cat(
            [global_rot, neck_pose, remaining_pose, eye_pose], dim=1)

        batched_template = self.v_template.unsqueeze(0).repeat(
            self.batch_size, 1, 1)

        # The dtype keyword is not forwarded here, so lbs uses its own default.
        vertices, _ = lbs(betas, full_pose, batched_template,
                          self.shapedirs, self.posedirs,
                          self.J_regressor, self.parents,
                          self.lbs_weights)

        face_ids = self.lmk_faces_idx.unsqueeze(dim=0).repeat(
            self.batch_size, 1)
        bary = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            self.batch_size, 1, 1)
        if self.use_face_contour:
            # Pose-dependent contour landmarks are placed before the static ones.
            contour_ids, contour_bary = self._find_dynamic_lmk_idx_and_bcoords(
                vertices, full_pose, self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain, dtype=self.dtype)
            face_ids = torch.cat([contour_ids, face_ids], 1)
            bary = torch.cat([contour_bary, bary], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       face_ids, bary)

        if self.use_3D_translation:
            # In-place translation of both outputs.
            offset = transl.unsqueeze(dim=1)
            landmarks += offset
            vertices += offset

        return vertices, landmarks
コード例 #3
0
    def forward(self,
                shape_params=None,
                expression_params=None,
                pose_params=None,
                neck_pose=None,
                transl=None,
                eye_pose=None):
        """Compute posed and translated vertices.

            Input:
                shape_params: N X number of shape parameters
                expression_params: N X number of expression parameters
                pose_params: N X number of pose parameters
                neck_pose: concatenated into the full pose as given
                transl: added to the vertices as given
                eye_pose: optional; self.eye_pose is used when None
            return:
                vertices: N X V X 3

            Only eye_pose has a fallback. Every other argument is used
            unconditionally, so it must be supplied despite its None default.
        """
        # Use the module's eye pose unless the caller provides one.
        if eye_pose is None:
            eye_pose = self.eye_pose

        betas = torch.cat([shape_params, expression_params], dim=1)

        # Pose layout: global rotation, neck, remaining pose columns, eyes.
        global_rot = pose_params[:, :3]
        remaining_pose = pose_params[:, 3:]
        full_pose = torch.cat(
            [global_rot, neck_pose, remaining_pose, eye_pose], dim=1)

        # The batch size is taken from the pose input, not from the module.
        num_samples = pose_params.shape[0]
        batched_template = self.v_template.unsqueeze(0).repeat(
            num_samples, 1, 1)

        vertices, _ = lbs(betas, full_pose, batched_template,
                          self.shapedirs, self.posedirs, self.J_regressor,
                          self.parents, self.lbs_weights, dtype=self.dtype)

        # In-place translation.
        vertices += transl.unsqueeze(dim=1)

        return vertices
コード例 #4
0
ファイル: body_models.py プロジェクト: zhly0/eft
    def forward(self,
                betas=None,
                global_orient=None,
                body_pose=None,
                left_hand_pose=None,
                right_hand_pose=None,
                transl=None,
                expression=None,
                jaw_pose=None,
                leye_pose=None,
                reye_pose=None,
                return_verts=True,
                return_full_pose=False,
                pose2rot=True,
                **kwargs):
        '''
        Forward pass for the SMPLX model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3
                If given, ignore the member variable and use it as the global
                rotation of the body. Useful if someone wishes to predict this
                with an external model. (default=None)
            betas: torch.tensor, optional, shape Bx10
                If given, ignore the member variable `betas` and use it
                instead. For example, it can be used if shape parameters
                `betas` are predicted from some external model.
                (default=None)
            expression: torch.tensor, optional, shape Bx10
                If given, ignore the member variable `expression` and use it
                instead. For example, it can be used if expression parameters
                `expression` are predicted from some external model.
            body_pose: torch.tensor, optional, shape Bx(J*3)
                If given, ignore the member variable `body_pose` and use it
                instead. For example, it can be used if the pose of the body
                joints is predicted from some external model.
                It should be a tensor that contains joint rotations in
                axis-angle format. (default=None)
            left_hand_pose: torch.tensor, optional, shape BxP
                If given, ignore the member variable `left_hand_pose` and
                use this instead. It should either contain PCA coefficients or
                joint rotations in axis-angle format.
            right_hand_pose: torch.tensor, optional, shape BxP
                If given, ignore the member variable `right_hand_pose` and
                use this instead. It should either contain PCA coefficients or
                joint rotations in axis-angle format.
            jaw_pose: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `jaw_pose` and
                use this instead. It should contain joint rotations in
                axis-angle format.
            leye_pose, reye_pose: torch.tensor, optional
                If given, used instead of the member variables `leye_pose`
                and `reye_pose`.
            transl: torch.tensor, optional, shape Bx3
                If given, ignore the member variable `transl` and use it
                instead. For example, it can be used if the translation
                `transl` is predicted from some external model.
                (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Returns the full pose tensor (default=False)
            pose2rot: bool, optional
                Forwarded to `lbs`. When False, `global_orient` and
                `body_pose` are concatenated as given (the code treats them
                as already being rotation matrices) and only the jaw, eye and
                hand poses are converted with `batch_rodrigues`.
                (default=True)

            Returns
            -------
                output: ModelOutput
                A named tuple of type `ModelOutput`
        '''

        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas

        left_hand_pose = (left_hand_pose if left_hand_pose is not None else
                          self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else
                           self.right_hand_pose)
        jaw_pose = jaw_pose if jaw_pose is not None else self.jaw_pose
        leye_pose = leye_pose if leye_pose is not None else self.leye_pose
        reye_pose = reye_pose if reye_pose is not None else self.reye_pose
        expression = expression if expression is not None else self.expression

        # Translation is applied when the caller passes one or the module
        # owns a `transl` attribute.
        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        if self.use_pca:
            # Project the hand inputs through the stored hand components.
            left_hand_pose = torch.einsum(
                'bi,ij->bj', [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                'bi,ij->bj', [right_hand_pose, self.right_hand_components])

        batch_size = max(betas.shape[0], global_orient.shape[0],
                         body_pose.shape[0])

        # Fixed rot matrix input issue. Assuming hands and others are still
        # angle axis
        if pose2rot:
            full_pose = torch.cat([
                global_orient, body_pose, jaw_pose, leye_pose, reye_pose,
                left_hand_pose, right_hand_pose
            ],
                                  dim=1)
            # Add the mean pose of the model.
            full_pose += self.pose_mean
        else:
            full_pose = torch.cat([global_orient, body_pose],
                                  dim=1)  # (N, 22, 3, 3)

            full_others = torch.cat([
                jaw_pose, leye_pose, reye_pose, left_hand_pose, right_hand_pose
            ],
                                    dim=1)
            # Only the slice of the mean pose that follows the global/body
            # joints is added to the remaining axis-angle poses.
            full_others += self.pose_mean[full_pose.shape[1] * 3:]

            full_pose_others = batch_rodrigues(full_others.view(-1, 3),
                                               dtype=self.dtype).view(
                                                   [batch_size, -1, 3,
                                                    3])  # (N, 33, 3, 3)

            full_pose = torch.cat([full_pose, full_pose_others],
                                  dim=1)  # (N, 55, 3, 3)

        # Concatenate the shape and expression coefficients
        scale = int(batch_size / betas.shape[0])
        if scale > 1:
            betas = betas.expand(scale, -1)
        shape_components = torch.cat([betas, expression], dim=-1)

        vertices, joints = lbs(shape_components,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot,
                               dtype=self.dtype)

        lmk_faces_idx = self.lmk_faces_idx.unsqueeze(dim=0).expand(
            batch_size, -1).contiguous()
        # Tile with the batch size of this call (not self.batch_size) so the
        # barycentric coordinates always match lmk_faces_idx and the vertices.
        lmk_bary_coords = self.lmk_bary_coords.unsqueeze(dim=0).repeat(
            batch_size, 1, 1)
        if self.use_face_contour:
            dyn_lmk_faces_idx, dyn_lmk_bary_coords = find_dynamic_lmk_idx_and_bcoords(
                vertices,
                full_pose,
                self.dynamic_lmk_faces_idx,
                self.dynamic_lmk_bary_coords,
                self.neck_kin_chain,
                dtype=self.dtype)

            # Static landmarks first, dynamic contour landmarks after them.
            lmk_faces_idx = torch.cat([lmk_faces_idx, dyn_lmk_faces_idx], 1)
            lmk_bary_coords = torch.cat(
                [lmk_bary_coords, dyn_lmk_bary_coords], 1)

        landmarks = vertices2landmarks(vertices, self.faces_tensor,
                                       lmk_faces_idx, lmk_bary_coords)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        # Add the landmarks to the joints
        joints = torch.cat([joints, landmarks], dim=1)

        # Map the joints to the current dataset
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints=joints, vertices=vertices)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        # Report the global orientation that was actually used, as the sibling
        # SMPL / SMPL+H forward passes do.
        # NOTE(review): the hand poses below are still the module members, not
        # the values used above; left unchanged because with use_pca the local
        # tensors have been projected through the hand components - confirm
        # which one callers expect.
        output = ModelOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             expression=expression,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=self.left_hand_pose,
                             right_hand_pose=self.right_hand_pose,
                             jaw_pose=jaw_pose,
                             full_pose=full_pose if return_full_pose else None)
        return output
コード例 #5
0
ファイル: body_models.py プロジェクト: zhly0/eft
    def forward(self,
                betas=None,
                global_orient=None,
                body_pose=None,
                left_hand_pose=None,
                right_hand_pose=None,
                transl=None,
                return_verts=True,
                return_full_pose=False,
                pose2rot=True,
                **kwargs):
        ''' Forward pass for a body model with articulated hands

            Parameters
            ----------
            betas: torch.tensor, optional
                Shape coefficients; `self.betas` is used when None.
            global_orient: torch.tensor, optional
                Global rotation; `self.global_orient` is used when None.
            body_pose: torch.tensor, optional
                Body joint rotations; `self.body_pose` is used when None.
            left_hand_pose, right_hand_pose: torch.tensor, optional
                Hand poses; the member variables are used when None. When
                `self.use_pca` is set they are projected through the stored
                hand components before use.
            transl: torch.tensor, optional
                Translation; `self.transl` is used when None and the module
                has that attribute.
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Return the concatenated pose tensor. (default=False)
            pose2rot: bool, optional
                Forwarded to `lbs`. (default=True)

            Returns
            -------
                output: ModelOutput
        '''
        # If no shape and pose parameters are passed along, then use the
        # ones from the module
        global_orient = (global_orient
                         if global_orient is not None else self.global_orient)
        body_pose = body_pose if body_pose is not None else self.body_pose
        betas = betas if betas is not None else self.betas
        left_hand_pose = (left_hand_pose if left_hand_pose is not None else
                          self.left_hand_pose)
        right_hand_pose = (right_hand_pose if right_hand_pose is not None else
                           self.right_hand_pose)

        # Translation is applied when the caller passes one or the module
        # owns a `transl` attribute.
        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None and hasattr(self, 'transl'):
            transl = self.transl

        if self.use_pca:
            left_hand_pose = torch.einsum(
                'bi,ij->bj', [left_hand_pose, self.left_hand_components])
            right_hand_pose = torch.einsum(
                'bi,ij->bj', [right_hand_pose, self.right_hand_components])

        full_pose = torch.cat(
            [global_orient, body_pose, left_hand_pose, right_hand_pose], dim=1)
        full_pose += self.pose_mean

        # Use the resolved `betas` (caller override or module default); passing
        # self.betas here would silently ignore a caller-supplied shape.
        vertices, joints = lbs(betas,
                               full_pose,
                               self.v_template,
                               self.shapedirs,
                               self.posedirs,
                               self.J_regressor,
                               self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot,
                               dtype=self.dtype)

        # Add any extra joints that might be needed
        joints = self.vertex_joint_selector(vertices, joints)
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            joints += transl.unsqueeze(dim=1)
            vertices += transl.unsqueeze(dim=1)

        output = ModelOutput(vertices=vertices if return_verts else None,
                             joints=joints,
                             betas=betas,
                             global_orient=global_orient,
                             body_pose=body_pose,
                             left_hand_pose=left_hand_pose,
                             right_hand_pose=right_hand_pose,
                             full_pose=full_pose if return_full_pose else None)

        return output
コード例 #6
0
ファイル: body_models.py プロジェクト: zhly0/eft
    def forward(self,
                betas=None,
                body_pose=None,
                global_orient=None,
                transl=None,
                return_verts=True,
                return_full_pose=False,
                pose2rot=True,
                **kwargs):
        ''' Forward pass for the SMPL model

            Parameters
            ----------
            global_orient: torch.tensor, optional, shape Bx3
                Global rotation of the body. Overrides the member variable
                when given, e.g. when it is predicted by an external model.
                (default=None)
            betas: torch.tensor, optional, shape Bx10
                Shape parameters. Overrides the member variable `betas` when
                given. (default=None)
            body_pose: torch.tensor, optional, shape Bx(J*3)
                Joint rotations in axis-angle format. Overrides the member
                variable `body_pose` when given. (default=None)
            transl: torch.tensor, optional, shape Bx3
                Translation. Overrides the member variable `transl` when
                given. (default=None)
            return_verts: bool, optional
                Return the vertices. (default=True)
            return_full_pose: bool, optional
                Return the concatenated pose tensor. (default=False)
            pose2rot: bool, optional
                Forwarded to `lbs`. (default=True)

            Returns
            -------
                output: ModelOutput
        '''
        # Anything the caller leaves out is taken from the module itself.
        if global_orient is None:
            global_orient = self.global_orient
        if body_pose is None:
            body_pose = self.body_pose
        if betas is None:
            betas = self.betas

        # A translation is applied if one is passed or the module owns one.
        apply_trans = transl is not None or hasattr(self, 'transl')
        if transl is None:
            if hasattr(self, 'transl'):
                transl = self.transl

        batch_size = max(betas.shape[0], global_orient.shape[0],
                         body_pose.shape[0])
        full_pose = torch.cat([global_orient, body_pose], dim=1)

        # Broadcast the shape coefficients when their batch is smaller.
        if betas.shape[0] != batch_size:
            num_repeats = int(batch_size / betas.shape[0])
            betas = betas.expand(num_repeats, -1)

        vertices, joints = lbs(betas, full_pose, self.v_template,
                               self.shapedirs, self.posedirs,
                               self.J_regressor, self.parents,
                               self.lbs_weights,
                               pose2rot=pose2rot, dtype=self.dtype)

        joints = self.vertex_joint_selector(vertices, joints)
        # Optionally remap the joints to the layout of the current dataset.
        if self.joint_mapper is not None:
            joints = self.joint_mapper(joints)

        if apply_trans:
            offset = transl.unsqueeze(dim=1)
            joints += offset
            vertices += offset

        returned_vertices = vertices if return_verts else None
        returned_pose = full_pose if return_full_pose else None
        return ModelOutput(vertices=returned_vertices,
                           global_orient=global_orient,
                           body_pose=body_pose,
                           joints=joints,
                           betas=betas,
                           full_pose=returned_pose)
コード例 #7
0
    def forward(self,
                root_orient=None,
                pose_body=None,
                pose_hand=None,
                pose_jaw=None,
                pose_eye=None,
                betas=None,
                trans=None,
                **kwargs):
        '''
        Pose the body model and return the result as an attribute container.

        Any argument left as None is replaced by the member variable of the
        same name, limited to the pose parts used by the current model_type.

        :param root_orient: Nx3
        :param pose_body: body pose (smpl / smplh / smplhf)
        :param pose_hand: hand pose (every model type)
        :param pose_jaw: jaw pose (smplhf only)
        :param pose_eye: eye pose (smplhf only)
        :param betas: shape coefficients
        :param trans: translation added to the vertices and the joints
        :param kwargs: ignored
        :return: object with attributes v, f, betas, Jtr, full_pose and the
            pose_* parts relevant to the model type
        '''
        assert self.model_type in [
            'smpl', 'smplh', 'smplhf', 'mano_left', 'mano_right'
        ], ValueError(
            'model_type should be in smpl/smplh/smplhf/mano_left/mano_right.')

        kind = self.model_type
        has_face = kind == 'smplhf'
        has_body = kind in ('smpl', 'smplh', 'smplhf')
        is_mano = kind in ('mano_left', 'mano_right')

        # Fill in whatever the caller did not provide.
        if root_orient is None:
            root_orient = self.root_orient
        if has_body and pose_body is None:
            pose_body = self.pose_body
        if (has_body or is_mano) and pose_hand is None:
            pose_hand = self.pose_hand
        if has_face:
            if pose_jaw is None:
                pose_jaw = self.pose_jaw
            if pose_eye is None:
                pose_eye = self.pose_eye
        if trans is None:
            trans = self.trans
        if betas is None:
            betas = self.betas

        # Assemble the pose vector in the order used by each model type.
        if has_face:
            # orient:3, body:63, jaw:3, eyel:3, eyer:3, handl, handr
            pose_parts = [root_orient, pose_body, pose_jaw, pose_eye,
                          pose_hand]
        elif has_body:
            pose_parts = [root_orient, pose_body, pose_hand]
        elif is_mano:
            pose_parts = [root_orient, pose_hand]
        full_pose = torch.cat(pose_parts, dim=1)

        # smplhf appends expression coefficients/directions to the shape ones.
        if has_face:
            shape_components = torch.cat([betas, self.expression], dim=-1)
            shapedirs = torch.cat([self.shapedirs, self.exprdirs], dim=-1)
        else:
            shape_components, shapedirs = betas, self.shapedirs

        verts, joints = lbs(betas=shape_components,
                            pose=full_pose,
                            v_template=self.v_template,
                            shapedirs=shapedirs,
                            posedirs=self.posedirs,
                            J_regressor=self.J_regressor,
                            parents=self.kintree_table[0].long(),
                            lbs_weights=self.weights,
                            num_joints=int(full_pose.shape[1] / 3),
                            dtype=self.dtype)

        offset = trans.unsqueeze(dim=1)
        Jtr = joints + offset
        verts = verts + offset

        class result_meta(object):
            pass

        res = result_meta()
        res.v = verts
        res.f = self.f
        # The module's betas are reported, not the ones passed in.
        res.betas = self.betas
        res.Jtr = Jtr  # Todo: ik can be made with vposer

        # Expose only the pose parts that exist for this model type.
        if has_body:
            res.pose_body = pose_body
        if kind != 'smpl' and (has_body or is_mano):
            res.pose_hand = pose_hand
        if has_face:
            res.pose_jaw = pose_jaw
            res.pose_eye = pose_eye
        res.full_pose = full_pose

        return res
コード例 #8
0
    def forward(self,
                root_orient=None,
                pose_body=None,
                pose_hand=None,
                pose_jaw=None,
                pose_eye=None,
                betas=None,
                trans=None,
                dmpls=None,
                expression=None,
                return_dict=False,
                v_template=None,
                **kwargs):
        '''
        Pose the body model and return the result as a dict or an object.

        Any argument left as None is replaced by the member variable of the
        same name, limited to the pose parts used by the current model_type.
        v_template and betas must not both be given.

        :param root_orient: Nx3
        :param pose_body: body pose (smpl / smplh / smplx)
        :param pose_hand: hand pose (every model type)
        :param pose_jaw: jaw pose (smplx only)
        :param pose_eye: eye pose (smplx only)
        :param betas: shape coefficients
        :param trans: translation added to the vertices and the joints
        :param dmpls: DMPL coefficients, used when self.use_dmpl is set
        :param expression: expression coefficients, used for smplx when
            self.use_dmpl is not set
        :param return_dict: return a dict instead of an attribute container
        :param v_template: template vertices overriding self.v_template
        :param kwargs: ignored
        :return: dict (or object with the same names as attributes) holding
            v, f, betas, Jtr, full_pose and the relevant pose_* parts
        '''
        assert v_template is None or betas is None, ValueError(
            'vtemplate and betas could not be used jointly.')
        assert self.model_type in [
            'smpl', 'smplh', 'smplx', 'mano'
        ], ValueError('model_type should be in smpl/smplh/smplx/mano')

        kind = self.model_type
        is_smplx = kind == 'smplx'
        has_body = kind in ('smpl', 'smplh', 'smplx')
        is_mano = kind in ('mano',)

        # Fill in whatever the caller did not provide.
        if root_orient is None:
            root_orient = self.root_orient
        if has_body and pose_body is None:
            pose_body = self.pose_body
        if (has_body or is_mano) and pose_hand is None:
            pose_hand = self.pose_hand
        if is_smplx:
            if pose_jaw is None:
                pose_jaw = self.pose_jaw
            if pose_eye is None:
                pose_eye = self.pose_eye
        if trans is None:
            trans = self.trans
        if v_template is None:
            v_template = self.v_template
        if betas is None:
            betas = self.betas

        # Assemble the pose vector in the order used by each model type.
        if is_smplx:
            # orient:3, body:63, jaw:3, eyel:3, eyer:3, handl, handr
            pose_parts = [root_orient, pose_body, pose_jaw, pose_eye,
                          pose_hand]
        elif has_body:
            pose_parts = [root_orient, pose_body, pose_hand]
        elif is_mano:
            pose_parts = [root_orient, pose_hand]
        full_pose = torch.cat(pose_parts, dim=1)

        # DMPLs take precedence over smplx expressions as extra components.
        if self.use_dmpl:
            if dmpls is None:
                dmpls = self.dmpls
            shape_components = torch.cat([betas, dmpls], dim=-1)
            shapedirs = torch.cat([self.shapedirs, self.dmpldirs], dim=-1)
        elif is_smplx:
            if expression is None:
                expression = self.expression
            shape_components = torch.cat([betas, expression], dim=-1)
            shapedirs = torch.cat([self.shapedirs, self.exprdirs], dim=-1)
        else:
            shape_components, shapedirs = betas, self.shapedirs

        verts, joints = lbs(betas=shape_components,
                            pose=full_pose,
                            v_template=v_template,
                            shapedirs=shapedirs,
                            posedirs=self.posedirs,
                            J_regressor=self.J_regressor,
                            parents=self.kintree_table[0].long(),
                            lbs_weights=self.weights,
                            dtype=self.dtype)

        offset = trans.unsqueeze(dim=1)
        Jtr = joints + offset
        verts = verts + offset

        # The module's betas are reported, not the ones passed in.
        res = {
            'v': verts,
            'f': self.f,
            'betas': self.betas,
            'Jtr': Jtr,  # Todo: ik can be made with vposer
        }

        # Expose only the pose parts that exist for this model type.
        if has_body:
            res['pose_body'] = pose_body
        if kind != 'smpl' and (has_body or is_mano):
            res['pose_hand'] = pose_hand
        if is_smplx:
            res['pose_jaw'] = pose_jaw
            res['pose_eye'] = pose_eye
        res['full_pose'] = full_pose

        if return_dict:
            return res

        # Otherwise copy the entries onto a plain attribute container.
        class result_meta(object):
            pass

        res_class = result_meta()
        for key, value in res.items():
            setattr(res_class, key, value)
        return res_class
コード例 #9
0
    def forward(self,
                shape_params=None,
                expression_params=None,
                pose_params=None,
                neck_pose=None,
                eye_pose=None,
                transl=None):
        """Compute the posed vertices.

            Input:
                shape_params: N X number of shape parameters
                expression_params: N X number of expression parameters
                pose_params: N X number of pose parameters; only the first
                    three columns (global rotation) enter the pose
                neck_pose, eye_pose: optional overrides; the values stored on
                    the module are used when they are None
                transl: accepted for interface compatibility, currently unused
            return:
                vertices: N X V X 3

            Landmarks are not computed by this variant. Despite their None
            defaults, shape_params, expression_params and pose_params are
            used unconditionally and must be supplied.
        """
        batch_size = pose_params.shape[0]

        # Caller coefficients are padded with the module's extra betas.
        betas = torch.cat([
            shape_params, self.shape_betas, expression_params,
            self.expression_betas
        ],
                          dim=1)
        neck_pose = (neck_pose if neck_pose is not None else self.neck_pose)
        eye_pose = (eye_pose if eye_pose is not None else self.eye_pose)

        # NOTE(review): neck_pose fills both the neck slot and the slot where
        # the other FLAME forward passes put pose_params[:, 3:]; kept as
        # written - confirm this is intentional.
        full_pose = torch.cat(
            [pose_params[:, :3], neck_pose, neck_pose, eye_pose], dim=1)

        # Tile the template with the batch size of the input rather than
        # self.batch_size, so it always matches the tensors of this call.
        template_vertices = self.v_template.unsqueeze(0).repeat(
            batch_size, 1, 1)

        vertices, _ = lbs(betas,
                          full_pose,
                          template_vertices,
                          self.shapedirs,
                          self.posedirs,
                          self.J_regressor,
                          self.parents,
                          self.lbs_weights,
                          dtype=self.dtype)

        # NOTE(review): an earlier draft also scaled the vertices by
        # -|pose_params[:, 4]|, shifted them by pose_params[:, 5:7] and
        # computed landmarks. That result was never returned (and the scaling
        # did not broadcast for general batch sizes), so the dead computation
        # and the debug print have been dropped.
        return vertices