Example 1

# Assumed imports for this example: FLAME is the base model class from the
# surrounding project, and bgr_to_rgb is a channel-reversal helper (e.g.
# kornia.color.bgr_to_rgb). The pytorch3d names match the pre-0.3 API used here.
import numpy as np
import torch
import torch.nn as nn
from pytorch3d.renderer import (DirectionalLights, MeshRasterizer,
                                MeshRenderer, OpenGLPerspectiveCameras,
                                RasterizationSettings, TexturedSoftPhongShader,
                                look_at_view_transform)
from pytorch3d.structures import Meshes, Textures

class TexturedFLAME(FLAME):
    def __init__(self, config, crop_size, device):
        super(TexturedFLAME, self).__init__(config)
        self.crop_size = crop_size
        texture_model = np.load(config.texture_path)
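        # The .npz provides 'mean' (mean texture image), 'tex_dir' (PCA
        # basis) and the UV layout ('vt' vertex UVs, 'ft' face UV indices).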

        self.texture_shape = texture_model['mean'].shape
        self.texture_num_pc = texture_model['tex_dir'].shape[-1]
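        # Flatten the mean to (1, H*W*C) and the basis to (num_pc, H*W*C) so
        # the texture can be synthesised with a single matmul in forward().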
        self.register_buffer(
            'texture_mean',
            torch.reshape(
                torch.from_numpy(texture_model['mean']).to(dtype=self.dtype),
                (1, -1)))
        self.register_buffer(
            'texture_dir',
            torch.reshape(
                torch.from_numpy(
                    texture_model['tex_dir']).to(dtype=self.dtype),
                (-1, self.texture_num_pc)).t())

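        # Per-face and per-vertex UVs, repeated across the batch as expected
        # by pytorch3d's Textures.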
        self.register_buffer(
            'faces_uvs',
            torch.cat(
                self.batch_size *
                [torch.from_numpy(np.int64(texture_model['ft'])).unsqueeze(0)
                 ]))
        self.register_buffer(
            'verts_uvs',
            torch.cat(self.batch_size * [
                torch.from_numpy(
                    texture_model['vt']).to(dtype=self.dtype).unsqueeze(0)
            ]))

        # Learnable PCA coefficients for the texture (nn.Parameter already
        # sets requires_grad=True).
        self.register_parameter(
            'texture_params',
            nn.Parameter(torch.zeros((1, self.texture_num_pc),
                                     dtype=self.dtype)))

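        # Hard rasterisation at the crop resolution: one face per pixel, no
        # blur; bin sizes are left to pytorch3d's internal heuristics.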
        raster_settings = RasterizationSettings(image_size=crop_size,
                                                blur_radius=0.0,
                                                faces_per_pixel=1,
                                                bin_size=None,
                                                max_faces_per_bin=None)

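        # Fixed frontal viewpoint: distance 1.0, elevation 0.5 degrees,
        # azimuth 0. R and T are buffers so they follow the module's device.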
        R, T = look_at_view_transform(1.0, 0.5, 0)
        self.register_buffer('renderer_R', R)
        self.register_buffer('renderer_T', T)
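        # OpenGLPerspectiveCameras is the pre-0.3 name for what newer
        # pytorch3d releases call FoVPerspectiveCameras.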
        self.renderer_camera = OpenGLPerspectiveCameras(R=self.renderer_R,
                                                        T=self.renderer_T,
                                                        fov=20,
                                                        device=device)
        renderer_lights = DirectionalLights(
            direction=[[0.0, 0.0, 2.0]],
            specular_color=[[0.0, 0.0, 0.0]],
            device=device)  # PointLights(location=[[0.0, 0.0, -1.0]])
        renderer_rasterizer = MeshRasterizer(cameras=self.renderer_camera,
                                             raster_settings=raster_settings)
        renderer_shader = TexturedSoftPhongShader(cameras=self.renderer_camera,
                                                  lights=renderer_lights,
                                                  device=device)
        self.renderer = MeshRenderer(rasterizer=renderer_rasterizer,
                                     shader=renderer_shader)

    def forward(self,
                shape_params=None,
                expression_params=None,
                pose_params=None,
                neck_pose=None,
                eye_pose=None,
                transl=None,
                texture_params=None):
        vertices, landmarks = super(TexturedFLAME,
                                    self).forward(shape_params,
                                                  expression_params,
                                                  pose_params, neck_pose,
                                                  eye_pose, transl)
        # Synthesise the texture from PCA coefficients; fall back to the
        # module's learnable parameters when none are passed in (the
        # texture_params argument was previously ignored).
        if texture_params is None:
            texture_params = self.texture_params
        texture = torch.reshape(
            self.texture_mean + texture_params @ self.texture_dir,
            self.texture_shape)
        texture = texture.clamp(0.0, 255.0) / 255.0  # normalise to [0, 1]
        texture = torch.cat(self.batch_size * [texture.unsqueeze(0)])

        textures = Textures(maps=texture,
                            faces_uvs=self.faces_uvs,
                            verts_uvs=self.verts_uvs)
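        # The face topology is shared across the batch, so faces_tensor is
        # repeated once per batch element when building the Meshes object.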
        meshes = Meshes(
            vertices,
            torch.cat(vertices.shape[0] * [self.faces_tensor.unsqueeze(0)]),
            textures)

        # Render, keep the three colour channels, reverse them with
        # bgr_to_rgb (presumably to match BGR-loaded target images) and move
        # channels first: (N, C, H, W).
        images = bgr_to_rgb(
            self.renderer(meshes)[..., :3].permute(0, -1, 1, 2))

        # Project landmarks through the camera and flip x/y from pytorch3d's
        # NDC convention (+X left, +Y up) to the image convention.
        landmarks = self.transform_points(landmarks)
        landmarks[:, :, 0] *= -1
        landmarks[:, :, 1] *= -1

        # Convert NDC x/y to pixel coordinates.
        landmarks[:, :, :2] = self._ndc_to_pix(landmarks[:, :, :2],
                                               self.crop_size)
        landmarks = landmarks[:, :, :2]  # x y only
        return vertices, landmarks, images

    def transform_points(self, points):
        return self.renderer_camera.transform_points(points)

    def _ndc_to_pix(self, i, S):
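        # Inverse of pytorch3d's pixel-to-NDC mapping: takes an NDC
        # coordinate in [-1, 1] to a pixel-centre coordinate.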
        return ((i + 1) * S - 1.0) / 2.0
Example 2
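    # Fragment from a fitting script: cameras, lights, raster_settings, mesh,
    # face_shape, facemodel, images_vec, image_size, batch_size and device
    # are defined earlier in the script.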
    renderer = MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras, 
            raster_settings=raster_settings
        ),
        shader=HardPhongShader(
            device=device, 
            cameras=cameras,
            lights=lights
        )
    )  
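    # The renderer returns an (N, H, W, 4) RGBA image batch.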
    images = renderer(mesh)
    print(images.size())

    transformed_face_shape = cameras.transform_points(face_shape)
    landmarks = transformed_face_shape[:, facemodel.keypoints, :]
    landmarks = ((landmarks + 1) * image_size - 1) / 2.
    landmarks[:, :, :2] = image_size - landmarks[:, :, :2]  # both the x and y coordinates need to be flipped
    landmarks = landmarks.cpu().numpy()
    for i in range(batch_size):
        cropped_image = images_vec[i]
        cropped_image = torch.tensor(cropped_image).float().to(device)
        a_image = images[i, ..., :3]
        a_image = a_image + torch.min(a_image)
        a_image = a_image.clamp(0, 255)
        print("cropped_image.size:\t", cropped_image.size())
        print('a_image.size:\t', a_image.size())
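        # Composite: overwrite the crop wherever the render is non-zero.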
        index = (a_image > 0)
        cropped_image[index] = a_image[index]