def render_shape(self, vertices, transformed_vertices, images=None, lights=None):
    """Render the mesh using its per-face colors, shaded by the given lights.

    Args:
        vertices: mesh vertices used to compute shading normals; presumably
            [N, V, 3] in world space (TODO confirm).
        transformed_vertices: projected vertices used for rasterization;
            presumably [N, V, 3]. The caller's tensor is NOT modified
            (a shifted copy is rasterized).
        images: unused; kept only so existing callers keep working.
        lights: spherical-harmonic coefficients when ``lights.shape[1] == 9``.
            Otherwise it is a set of directional lights passed to
            ``self.add_directionlight``. When None, two white lights are used:
            position (xyz) concatenated with intensity (rgb), i.e. [N, 2, 6].

    Returns:
        The albedo (``self.face_colors``) multiplied by the shading. For
        directional lights, the shading is averaged over all lights.
    """
    batch_size = vertices.shape[0]
    faces = self.faces.expand(batch_size, -1, -1)

    if lights is None:
        # Two default lights: one slightly off-axis, one along +z; unit intensity.
        light_positions = torch.tensor(
            [[-0.1, -0.1, 0.2], [0, 0, 1]],
            dtype=torch.float32,
            device=vertices.device,
        )[None, :, :].expand(batch_size, -1, -1)
        light_intensities = torch.ones_like(light_positions)
        lights = torch.cat((light_positions, light_intensities), 2)

    # The rasterizer clips to near 0 / far 100, so move the mesh so min z > 0.
    # Shift a copy: mutating the caller's tensor would accumulate +10 on
    # every call and fails on leaf tensors that require grad.
    transformed_vertices = transformed_vertices.clone()
    transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10

    # Per-face attributes that the rasterizer interpolates per pixel.
    face_vertices = util.face_vertices(vertices, faces)
    normals = util.vertex_normals(vertices, faces)
    face_normals = util.face_vertices(normals, faces)
    transformed_normals = util.vertex_normals(transformed_vertices, faces)
    transformed_face_normals = util.face_vertices(transformed_normals, faces)

    # Expected channel layout of `rendering`, following the concatenation
    # order below: [0:3] colors, [3:6] transformed normals,
    # [6:9] vertices, [9:12] normals.
    attributes = torch.cat([
        self.face_colors.expand(batch_size, -1, -1, -1),
        transformed_face_normals.detach(),
        face_vertices.detach(),
        face_normals.detach(),
    ], -1)
    rendering = self.rasterizer(transformed_vertices, faces, attributes)

    albedo_images = rendering[:, :3, :, :]
    normal_images = rendering[:, 9:12, :, :].detach()

    if lights.shape[1] == 9:
        # Spherical-harmonic lighting.
        shading_images = self.add_SHlight(normal_images, lights)
    else:
        # Directional lights: shade every pixel per light, then average the lights.
        height, width = albedo_images.shape[2], albedo_images.shape[3]
        flat_normals = normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3])
        shading = self.add_directionlight(flat_normals, lights)
        shading_images = shading.reshape(
            [batch_size, lights.shape[1], height, width, 3]
        ).permute(0, 1, 4, 2, 3).mean(1)

    return albedo_images * shading_images
def forward(self, vertices, transformed_vertices, albedos, lights=None, light_type='point'):
    """Render textured, lit images of the mesh.

    Args:
        vertices: [N, V, 3] vertices in world space, used to compute normals
            and then shading.
        transformed_vertices: [N, V, 3] projected vertices in range (-1, 1),
            used for rasterization. The caller's tensor is NOT modified
            (a shifted copy is rasterized).
        albedos: albedo texture(s). They are sampled with ``F.grid_sample`` at
            the rasterized UV coordinates; presumably [N, C, H_tex, W_tex].
        lights: optional.
            - Spherical harmonics: [N, 9 (sh coeff), 3 (rgb)], detected by
              ``lights.shape[1] == 9``.
            - Otherwise: point or directional lights, where ``lights.shape[1]``
              is the number of lights. Presumably each light is
              position + intensity (TODO confirm).
            - When None, no shading is applied.
        light_type: 'point' shades with ``self.add_pointlight``; any other
            value uses ``self.add_directionlight``. Ignored for SH lights.

    Returns:
        dict with:
            'images': albedo * shading, multiplied by the alpha mask.
            'albedo_images': texture sampled at the rendered UVs.
            'alpha_images': last rendered channel, detached, [N, 1, H, W].
            'pos_mask': 1.0 where the transformed normal's z < -0.05 (used
                to remove the inner mouth region), else 0.0.
            'shading_images': the shading; zeros when ``lights`` is None.
            'grid': the UV sampling grid passed to ``grid_sample``.
            'normals': world-space vertex normals.
    """
    batch_size = vertices.shape[0]
    faces = self.faces.expand(batch_size, -1, -1)

    # The rasterizer clips to near 0 / far 100, so move the mesh so min z > 0.
    # Shift a copy: mutating the caller's tensor would accumulate +10 on
    # every call and fails on leaf tensors that require grad.
    transformed_vertices = transformed_vertices.clone()
    transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10

    # Per-face attributes that the rasterizer interpolates per pixel.
    face_vertices = util.face_vertices(vertices, faces)
    normals = util.vertex_normals(vertices, faces)
    face_normals = util.face_vertices(normals, faces)
    transformed_normals = util.vertex_normals(transformed_vertices, faces)
    transformed_face_normals = util.face_vertices(transformed_normals, faces)

    # Expected channel layout of `rendering`, following the concatenation
    # order below: [0:3] uv coords, [3:6] transformed normals,
    # [6:9] vertices, [9:12] normals; the last channel is used as alpha.
    attributes = torch.cat([
        self.face_uvcoords.expand(batch_size, -1, -1, -1),
        transformed_face_normals.detach(),
        face_vertices.detach(),
        face_normals.detach(),
    ], -1)
    rendering = self.rasterizer(transformed_vertices, faces, attributes)

    alpha_images = rendering[:, -1, :, :][:, None, :, :].detach()

    # Albedo: sample the texture at the interpolated UV coordinates (x, y only).
    uvcoords_images = rendering[:, :3, :, :]
    grid = uvcoords_images.permute(0, 2, 3, 1)[:, :, :, :2]
    albedo_images = F.grid_sample(albedos, grid, align_corners=False)

    # Mask for removing the inner mouth region: pixels whose transformed
    # normal points away (z < -0.05).
    transformed_normal_map = rendering[:, 3:6, :, :].detach()
    pos_mask = (transformed_normal_map[:, 2:, :, :] < -0.05).float()

    if lights is None:
        images = albedo_images
        shading_images = images.detach() * 0.
    else:
        normal_images = rendering[:, 9:12, :, :].detach()
        if lights.shape[1] == 9:
            shading_images = self.add_SHlight(normal_images, lights)
        else:
            # Flatten pixels to [N, H*W, 3] for the per-light shading helpers.
            flat_normals = normal_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3])
            if light_type == 'point':
                vertice_images = rendering[:, 6:9, :, :].detach()
                flat_vertices = vertice_images.permute(0, 2, 3, 1).reshape([batch_size, -1, 3])
                shading = self.add_pointlight(flat_vertices, flat_normals, lights)
            else:
                shading = self.add_directionlight(flat_normals, lights)
            # Back to image layout per light, then average over the lights.
            shading_images = shading.reshape(
                [batch_size, lights.shape[1], albedo_images.shape[2], albedo_images.shape[3], 3]
            ).permute(0, 1, 4, 2, 3).mean(1)
        images = albedo_images * shading_images

    return {
        'images': images * alpha_images,
        'albedo_images': albedo_images,
        'alpha_images': alpha_images,
        'pos_mask': pos_mask,
        'shading_images': shading_images,
        'grid': grid,
        'normals': normals,
    }