    def render_shape(self,
                     vertices,
                     transformed_vertices,
                     images=None,
                     lights=None):
        batch_size = vertices.shape[0]
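        # default lighting (an assumption made explicit here): two lights, each row
        # packed as [x, y, z, r, g, b], i.e. position/direction followed by rgb intensity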
        if lights is None:
            light_positions = torch.tensor([[-0.1, -0.1, 0.2],
                                            [0, 0, 1]])[None, :, :].expand(
                                                batch_size, -1, -1).float()
            light_intensities = torch.ones_like(light_positions).float()
            lights = torch.cat((light_positions, light_intensities),
                               2).to(vertices.device)

        ## the rasterizer uses near=0, far=100; shift the mesh along z so its minimum z is greater than 0
        ## (note: this modifies the caller's transformed_vertices in place)
        transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10

        # Attributes
        face_vertices = util.face_vertices(
            vertices, self.faces.expand(batch_size, -1, -1))
        normals = util.vertex_normals(vertices,
                                      self.faces.expand(batch_size, -1, -1))
        face_normals = util.face_vertices(
            normals, self.faces.expand(batch_size, -1, -1))
        transformed_normals = util.vertex_normals(
            transformed_vertices, self.faces.expand(batch_size, -1, -1))
        transformed_face_normals = util.face_vertices(
            transformed_normals, self.faces.expand(batch_size, -1, -1))
        # render
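        # attribute channel layout: [0:3] per-face colors, [3:6] transformed-space normals,
        # [6:9] world-space positions, [9:12] world-space normals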
        attributes = torch.cat([
            self.face_colors.expand(batch_size, -1, -1, -1),
            transformed_face_normals.detach(),
            face_vertices.detach(),
            face_normals.detach()
        ], -1)
        rendering = self.rasterizer(transformed_vertices,
                                    self.faces.expand(batch_size, -1, -1),
                                    attributes)
        # albedo
        albedo_images = rendering[:, :3, :, :]
        # shading
        normal_images = rendering[:, 9:12, :, :].detach()
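        # lights with 9 rows are spherical-harmonics coefficients [N, 9, 3];
        # otherwise each row is a directional light packed as [x, y, z, r, g, b]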
        if lights.shape[1] == 9:
            shading_images = self.add_SHlight(normal_images, lights)
        else:
            # directional lights: compute shading per light, then average over lights below
            shading = self.add_directionlight(
                normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
                lights)

            shading_images = shading.reshape([
                batch_size, lights.shape[1], albedo_images.shape[2],
                albedo_images.shape[3], 3
            ]).permute(0, 1, 4, 2, 3)
            shading_images = shading_images.mean(1)
        images = albedo_images * shading_images

        return images

    def forward(self, vertices, transformed_vertices, albedos, lights=None, light_type='point'):
        '''
        lights:
            spherical harmonics: [N, 9 (SH coefficients), 3 (rgb)]
            point/directional: [N, n_lights, 6], each light packed as [x, y, z, r, g, b]
        vertices: [N, V, 3], vertices in world space, used to compute normals for shading
        transformed_vertices: [N, V, 3], in range (-1, 1), projected vertices used for rasterization
        '''
        batch_size = vertices.shape[0]
        ## the rasterizer uses near=0, far=100; shift the mesh along z so its minimum z is greater than 0
        ## (note: this modifies the caller's transformed_vertices in place)
        transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10

        # Attributes
        face_vertices = util.face_vertices(vertices, self.faces.expand(batch_size, -1, -1))
        normals = util.vertex_normals(vertices, self.faces.expand(batch_size, -1, -1))
        face_normals = util.face_vertices(normals, self.faces.expand(batch_size, -1, -1))
        transformed_normals = util.vertex_normals(transformed_vertices, self.faces.expand(batch_size, -1, -1))
        transformed_face_normals = util.face_vertices(transformed_normals, self.faces.expand(batch_size, -1, -1))

        # render
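        # attribute channel layout: [0:3] uv coordinates, [3:6] transformed-space normals,
        # [6:9] world-space positions, [9:12] world-space normals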
        attributes = torch.cat([self.face_uvcoords.expand(batch_size, -1, -1, -1), transformed_face_normals.detach(),
                                face_vertices.detach(), face_normals.detach()], -1)
        rendering = self.rasterizer(transformed_vertices, self.faces.expand(batch_size, -1, -1), attributes)

        # last rendering channel: per-pixel coverage mask produced by the rasterizer
        alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()

        # albedo
        uvcoords_images = rendering[:, :3, :, :]
        grid = uvcoords_images.permute(0, 2, 3, 1)[:, :, :, :2]

        # sample the uv albedo texture at the rasterized per-pixel uv coordinates
        albedo_images = F.grid_sample(albedos, grid, align_corners=False)

        # remove inner mouth region
        transformed_normal_map = rendering[:, 3:6, :, :].detach()
        pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float()

        # shading
        if lights is not None:
            normal_images = rendering[:, 9:12, :, :].detach()
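            # 9 rows -> spherical-harmonics coefficients; otherwise each row is a
            # point/directional light packed as [x, y, z, r, g, b]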
            if lights.shape[1] == 9:
                shading_images = self.add_SHlight(normal_images, lights)
            else:
                if light_type == 'point':
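                    # per-pixel world-space positions (attribute channels 6:9),
                    # needed to compute per-light directions for point lights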
                    vertice_images = rendering[:, 6:9, :, :].detach()
                    shading = self.add_pointlight(vertice_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
                                                  normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
                                                  lights)
                    shading_images = shading.reshape([
                        batch_size, lights.shape[1], albedo_images.shape[2],
                        albedo_images.shape[3], 3
                    ]).permute(0, 1, 4, 2, 3)
                    shading_images = shading_images.mean(1)
                else:
                    shading = self.add_directionlight(normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3]),
                                                      lights)
                    shading_images = shading.reshape([
                        batch_size, lights.shape[1], albedo_images.shape[2],
                        albedo_images.shape[3], 3
                    ]).permute(0, 1, 4, 2, 3)
                    shading_images = shading_images.mean(1)
            images = albedo_images * shading_images
        else:
            images = albedo_images
            shading_images = images.detach() * 0.

        outputs = {
            'images': images * alpha_images,
            'albedo_images': albedo_images,
            'alpha_images': alpha_images,
            'pos_mask': pos_mask,
            'shading_images': shading_images,
            'grid': grid,
            'normals': normals
        }

        return outputs
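
    # Usage sketch (illustrative only): the construction step and tensor shapes below are
    # assumptions, not confirmed by this file; only the forward() signature above is.
    #
    #   renderer = ...                                # an instance of this renderer class
    #   verts = torch.rand(N, V, 3)                   # world-space vertices
    #   proj_verts = torch.rand(N, V, 3) * 2 - 1      # projected vertices in (-1, 1)
    #   albedos = torch.rand(N, 3, 256, 256)          # uv albedo textures (assumed size)
    #   sh = torch.rand(N, 9, 3)                      # spherical-harmonics lighting
    #   out = renderer(verts, proj_verts, albedos, lights=sh)
    #   shaded = out['images']                        # already multiplied by the alpha mask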