def phong_shading(
    meshes, fragments, lights, cameras, materials, texels
) -> torch.Tensor:
    """
    Apply per pixel shading. First interpolate the vertex normals and
    vertex coordinates using the barycentric coordinates to get the position
    and normal at each pixel. Then compute the illumination for each pixel.
    The pixel color is obtained by multiplying the pixel textures by the
    ambient and diffuse illumination and adding the specular component.

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights
        cameras: Cameras class containing a batch of cameras
        materials: Materials class containing a batch of material properties
        texels: texture per pixel of shape (N, H, W, K, 3)

    Returns:
        colors: (N, H, W, K, 3)
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )
    ambient, diffuse, specular = _apply_lighting(
        pixel_coords, pixel_normals, lights, cameras, materials
    )
    colors = (ambient + diffuse) * texels + specular
    return colors
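# Usage sketch for phong_shading. This is a hedged example, not part of the
# renderer: it assumes the standard PyTorch3D renderer classes, `device` and
# `meshes` are placeholders defined elsewhere, and exact import paths may
# vary by version.
#
# from pytorch3d.renderer import (
#     FoVPerspectiveCameras, PointLights, Materials,
#     MeshRasterizer, RasterizationSettings,
# )
#
# cameras = FoVPerspectiveCameras(device=device)
# rasterizer = MeshRasterizer(
#     cameras=cameras, raster_settings=RasterizationSettings(image_size=256)
# )
# fragments = rasterizer(meshes)
# texels = meshes.textures.sample_textures(fragments)  # (N, H, W, K, 3)
# colors = phong_shading(
#     meshes, fragments, PointLights(device=device), cameras,
#     Materials(device=device), texels,
# )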
def neural_shading(
    meshes, fragments, lights, cameras, pixels_uv, texels, NN
) -> torch.Tensor:
    """
    Interpolate per-pixel positions and normals from the rasterized fragments,
    then delegate shading to a neural network (`NN`) via `_neural_shading`.
    """
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_verts = verts[faces]
    faces_normals = vertex_normals[faces]
    pixel_coords = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts
    )
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )
    return _neural_shading(
        pixel_coords, pixel_normals, lights, cameras, pixels_uv, NN,
    )
def sample_textures(self, fragments, faces_packed=None) -> torch.Tensor:
    """
    Determine the color for each rasterized face. Interpolate the colors for
    vertices which form the face using the barycentric coordinates.

    Args:
        fragments: The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.

    Returns:
        texels: A texture per pixel of shape (N, H, W, K, C).
        There will be one C dimensional value for each element in
        fragments.pix_to_face.
    """
    verts_features_packed = self.verts_features_packed()
    faces_verts_features = verts_features_packed[faces_packed]
    texels = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts_features
    )
    return texels
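# Usage sketch (hedged): assuming the method above belongs to a
# TexturesVertex-style class, the matching packed faces come from the mesh,
# and `meshes`/`fragments` are placeholders from a prior rasterization.
#
# texels = meshes.textures.sample_textures(
#     fragments, faces_packed=meshes.faces_packed()
# )  # (N, H, W, K, C)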
def _interpolate_zbuf(
    pix_to_face: torch.Tensor, barycentric_coords: torch.Tensor, meshes
) -> torch.Tensor:
    """
    A helper function to calculate the z buffer for each pixel in the
    rasterized output.

    Args:
        pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
            of the faces (in the packed representation) which
            overlap each pixel in the image.
        barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
            the barycentric coordinates of each pixel
            relative to the faces (in the packed
            representation) which overlap the pixel.
        meshes: Meshes object representing a batch of meshes.

    Returns:
        zbuffer: (N, H, W, K) FloatTensor
    """
    verts = meshes.verts_packed()
    faces = meshes.faces_packed()
    faces_verts_z = verts[faces][..., 2][..., None]  # (F, 3, 1)
    zbuf = interpolate_face_attributes(
        pix_to_face, barycentric_coords, faces_verts_z
    )[..., 0]  # (N, H, W, K)
    zbuf[pix_to_face == -1] = -1
    return zbuf
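# Minimal follow-up sketch: because _interpolate_zbuf writes the sentinel -1
# wherever pix_to_face is -1, a background mask falls out directly
# (`fragments` and `meshes` are placeholders from a prior rasterization).
#
# zbuf = _interpolate_zbuf(fragments.pix_to_face, fragments.bary_coords, meshes)
# background = zbuf == -1  # (N, H, W, K) bool, True where no face covers the pixel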
def interpolate_vertex_colors(fragments, meshes) -> torch.Tensor:
    """
    Determine the color for each rasterized face. Interpolate the colors for
    vertices which form the face using the barycentric coordinates.

    Args:
        meshes: A Meshes class representing a batch of meshes.
        fragments: The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.

    Returns:
        texels: A texture per pixel of shape (N, H, W, K, C).
        There will be one C dimensional value for each element in
        fragments.pix_to_face.
    """
    vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3)  # (V, C)
    vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :]
    faces_packed = meshes.faces_packed()
    faces_textures = vertex_textures[faces_packed]  # (F, 3, C)
    texels = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_textures
    )
    return texels
def forward(self, fragments, meshes, **kwargs):
    colors = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, self.face_colors
    )
    blend_params = BlendParams(sigma=1e-4, gamma=self.blend_gamma)
    imgs = soft_feature_blending(colors, fragments, blend_params=blend_params)
    return imgs
def gouraud_shading(meshes, fragments, lights, cameras, materials) -> torch.Tensor:
    """
    Apply per vertex shading. First compute the vertex illumination by applying
    ambient, diffuse and specular lighting. If vertex color is available,
    combine the ambient and diffuse vertex illumination with the vertex color
    and add the specular component to determine the vertex shaded color.
    Then interpolate the vertex shaded colors using the barycentric coordinates
    to get a color per pixel.

    Gouraud shading is only supported for meshes with texture type
    `TexturesVertex`. This is because the illumination is applied to the
    vertex colors.

    Args:
        meshes: Batch of meshes
        fragments: Fragments named tuple with the outputs of rasterization
        lights: Lights class containing a batch of lights parameters
        cameras: Cameras class containing a batch of cameras parameters
        materials: Materials class containing a batch of material properties

    Returns:
        colors: (N, H, W, K, 3)
    """
    if not isinstance(meshes.textures, TexturesVertex):
        raise ValueError("Mesh textures must be an instance of TexturesVertex")

    faces = meshes.faces_packed()  # (F, 3)
    verts = meshes.verts_packed()  # (V, 3)
    verts_normals = meshes.verts_normals_packed()  # (V, 3)
    verts_colors = meshes.textures.verts_features_packed()  # (V, D)
    vert_to_mesh_idx = meshes.verts_packed_to_mesh_idx()

    # Format properties of lights and materials so they are compatible
    # with the packed representation of the vertices. This transforms
    # all tensor properties in the class from shape (N, ...) -> (V, ...) where
    # V is the number of packed vertices. If the number of meshes in the
    # batch is one then this is not necessary.
    if len(meshes) > 1:
        lights = lights.clone().gather_props(vert_to_mesh_idx)
        cameras = cameras.clone().gather_props(vert_to_mesh_idx)
        materials = materials.clone().gather_props(vert_to_mesh_idx)

    # Calculate the illumination at each vertex
    ambient, diffuse, specular = _apply_lighting(
        verts, verts_normals, lights, cameras, materials
    )
    verts_colors_shaded = verts_colors * (ambient + diffuse) + specular
    face_colors = verts_colors_shaded[faces]
    colors = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, face_colors
    )
    return colors
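# Usage sketch for gouraud_shading (hedged: assumes PyTorch3D's TexturesVertex;
# the uniform white vertex colors and the `meshes`/`fragments`/`lights`/
# `cameras`/`materials` names are placeholders).
#
# from pytorch3d.renderer import TexturesVertex
#
# verts_rgb = torch.ones_like(meshes.verts_padded())  # (N, V, 3)
# meshes.textures = TexturesVertex(verts_features=verts_rgb)
# colors = gouraud_shading(meshes, fragments, lights, cameras, materials)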
def get_view_to_texture_map(self, meshes_world, **kwargs):
    fragments = self.rasterizer(meshes_world, **kwargs)
    tex = meshes_world.textures
    tex_map = tex.maps_padded()
    packing_list = [
        i[j] for i, j in zip(tex.verts_uvs_list(), tex.faces_uvs_list())
    ]
    faces_verts_uvs = torch.cat(packing_list)
    pixel_uvs = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
    )
    return pixel_uvs
def forward(self, fragments, meshes, **kwargs) -> torch.Tensor:
    """
    LightfieldShader just returns interpolated vertex coordinates for a plane.
    """
    # Get the faces and vertex positions from the mesh.
    faces = meshes.faces_packed()  # (F, 3)
    verts = meshes.verts_packed()  # (V, 3)
    faces_verts = verts[faces]
    Nv, H_out, W_out, K = fragments.pix_to_face.shape
    # pixel_verts: (Nv, H, W, K=1, 3) -> (Nv, K=1, H, W, 3) -> (Nv*K=1, H, W, 3)
    K = 1
    pixel_verts = interpolate_face_attributes(
        fragments.pix_to_face[:, :, :, 0:K],
        fragments.bary_coords[:, :, :, 0:K, :],
        faces_verts,
    )
    pixel_verts = pixel_verts.permute(0, 3, 1, 2, 4).view(Nv * K, H_out, W_out, 3)
    return pixel_verts
def sample_pixel_uvs(self, fragments, **kwargs) -> torch.Tensor:
    """
    Copied from super().sample_textures and adapted to output pixel_uvs
    instead of the sampled texture.

    Args:
        fragments: The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.

    Returns:
        pixel_uvs: tensor of shape (N * K, H, W, 2) giving the interpolated
        uv coordinates for each pixel in the rasterized image, rescaled to
        the [-1, 1] range expected by grid_sample.
    """
    if self.isempty():
        faces_verts_uvs = torch.zeros(
            (self._N, 3, 2), dtype=torch.float32, device=self.device
        )
    else:
        packing_list = [
            i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
        ]
        # Each face has 3 vertices, each with a (u, v) coordinate: (F, 3, 2)
        faces_verts_uvs = torch.cat(packing_list)

    # pixel_uvs: (N, H, W, K, 2)
    pixel_uvs = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
    )
    N, H_out, W_out, K = fragments.pix_to_face.shape
    # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)
    pixel_uvs = pixel_uvs * 2.0 - 1.0
    return pixel_uvs
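# Sketch of how the rescaled uvs would be consumed downstream (hedged:
# assumes a `texture_maps` tensor already formatted to (N*K, C, H, W) as in
# sample_textures below, and a `textures` instance of this class).
#
# pixel_uvs = textures.sample_pixel_uvs(fragments)  # (N*K, H, W, 2) in [-1, 1]
# texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)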
def debug_shading(meshes, fragments) -> torch.Tensor:
    return fragments.bary_coords

    # The blocks below are unreachable alternative debug outputs; move the
    # `return` statements around to enable one of them.

    # Alternative 1: interpolated per-pixel normals.
    faces = meshes.faces_packed()  # (F, 3)
    vertex_normals = meshes.verts_normals_packed()  # (V, 3)
    faces_normals = vertex_normals[faces]
    pixel_normals = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_normals
    )
    return pixel_normals

    # Alternative 2: flat shading. Note this path additionally expects
    # `lights`, `cameras`, `materials` and `texels`, which are not in this
    # function's signature as written.
    verts = meshes.verts_packed()  # (V, 3)
    faces = meshes.faces_packed()  # (F, 3)
    face_normals = meshes.faces_normals_packed()  # (F, 3)
    faces_verts = verts[faces]
    face_coords = faces_verts.mean(dim=-2)  # (F, 3) mean xyz across verts

    # Replace empty pixels in pix_to_face with 0 in order to interpolate.
    mask = fragments.pix_to_face == -1
    pix_to_face = fragments.pix_to_face.clone()
    pix_to_face[mask] = 0
    N, H, W, K = pix_to_face.shape
    idx = pix_to_face.view(N * H * W * K, 1).expand(N * H * W * K, 3)

    # Gather pixel coords.
    pixel_coords = face_coords.gather(0, idx).view(N, H, W, K, 3)
    pixel_coords[mask] = 0.0
    # Gather pixel normals.
    pixel_normals = face_normals.gather(0, idx).view(N, H, W, K, 3)
    pixel_normals[mask] = 0.0

    # Calculate the illumination at each face.
    ambient, diffuse, specular = _apply_lighting(
        pixel_coords, pixel_normals, lights, cameras, materials
    )
    colors = (ambient + diffuse) * texels + specular
    return colors
def sample_textures(self, fragments, **kwargs) -> torch.Tensor:
    """
    Interpolate a 2D texture map using uv vertex texture coordinates for each
    face in the mesh. First interpolate the vertex uvs using barycentric coordinates
    for each pixel in the rasterized output. Then interpolate the texture map
    using the uv coordinate for each pixel.

    Args:
        fragments: The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.

    Returns:
        texels: tensor of shape (N, H, W, K, C) giving the interpolated
        texture for each pixel in the rasterized image.
    """
    if self.isempty():
        faces_verts_uvs = torch.zeros(
            (self._N, 3, 2), dtype=torch.float32, device=self.device
        )
    else:
        packing_list = [
            i[j] for i, j in zip(self.verts_uvs_list(), self.faces_uvs_list())
        ]
        faces_verts_uvs = torch.cat(packing_list)
    texture_maps = self.maps_padded()

    # pixel_uvs: (N, H, W, K, 2)
    pixel_uvs = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
    )

    N, H_out, W_out, K = fragments.pix_to_face.shape
    N, H_in, W_in, C = texture_maps.shape  # 3 for RGB

    # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)

    # textures.map:
    #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
    #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
    texture_maps = (
        texture_maps.permute(0, 3, 1, 2)[None, ...]
        .expand(K, -1, -1, -1, -1)
        .transpose(0, 1)
        .reshape(N * K, C, H_in, W_in)
    )

    # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
    # Now need to format the pixel uvs and the texture map correctly!
    # From pytorch docs, grid_sample takes `grid` and `input`:
    #   grid specifies the sampling pixel locations normalized by
    #   the input spatial dimensions. It should have most
    #   values in the range of [-1, 1]. Values x = -1, y = -1
    #   is the left-top pixel of input, and values x = 1, y = 1 is the
    #   right-bottom pixel of input.
    pixel_uvs = pixel_uvs * 2.0 - 1.0

    texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
    if texture_maps.device != pixel_uvs.device:
        texture_maps = texture_maps.to(pixel_uvs.device)
    texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
    # texels now has shape (NK, C, H_out, W_out)
    texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
    return texels
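# Self-contained sketch of the grid_sample convention used above; this is an
# illustrative check, not part of the renderer. A uv of (0, 0) rescales to
# grid (-1, -1), which lands at the top-left corner of the y-flipped map
# (partly in zero padding under align_corners=False).
import torch
import torch.nn.functional as F

tex = torch.arange(16, dtype=torch.float32).view(1, 1, 4, 4)  # (N, C, H, W)
tex = torch.flip(tex, [2])            # flip y, as sample_textures does
uv = torch.tensor([[[[0.0, 0.0]]]])   # (N, H_out, W_out, 2), uv in [0, 1]
grid = uv * 2.0 - 1.0                 # rescale to grid_sample's [-1, 1] range
texel = F.grid_sample(tex, grid, align_corners=False)  # (1, 1, 1, 1)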
def main(cfg: DictConfig):
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    obj_path = cfg.data.obj_path
    texture_path = cfg.data.texture_path
    views_folder = cfg.data.views_folder
    params_file = os.path.join(views_folder, "params.json")

    dataset = CowMultiViews(
        obj_path, views_folder, texture_path, params_file=params_file
    )
    train_dataset, validation_dataset, test_dataset = CowMultiViews.random_split_dataset(
        dataset, train_fraction=0.7, validation_fraction=0.2
    )
    del dataset

    train_dataset.unit_normalize()
    validation_dataset.unit_normalize()
    test_dataset.unit_normalize()

    mesh_verts = test_dataset.get_verts()
    mesh_edges = test_dataset.get_edges()
    mesh_vert_normals = test_dataset.get_vert_normals()
    mesh_texture = test_dataset.get_texture()
    pytorch_mesh = test_dataset.pytorch_mesh.cuda()
    face_attrs = test_dataset.get_faces_as_vertex_matrices()
    feature_size = test_dataset.param_vectors.shape[1]

    torch_verts = torch.from_numpy(np.array(mesh_verts)).float().cuda()
    torch_edges = torch.from_numpy(np.array(mesh_edges)).long().cuda()
    torch_normals = torch.from_numpy(np.array(mesh_vert_normals)).float().cuda()
    torch_texture = torch.from_numpy(np.array(mesh_texture)).float().cuda()
    torch_texture = torch.unsqueeze(torch_texture.permute(2, 0, 1), 0)
    torch_face_attrs = torch.from_numpy(np.array(face_attrs)).float().cuda()

    subset_indices = [82]  # random.sample(list(range(len(test_dataset))), 1)
    test_dataloader = Subset(test_dataset, subset_indices)
    print(subset_indices, len(test_dataloader))

    image_translator = ImageTranslator(
        input_dim=6, output_dim=3, image_size=tuple(cfg.data.image_size)
    ).cuda()
    mse_loss = torch.nn.MSELoss()

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        image_translator.parameters(),
        lr=cfg.optimizer.lr,
    )

    stats = None
    start_epoch = 0
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)

    # Init the stats object.
    if stats is None:
        stats = Stats(["mse_loss", "sec/it"])

    # Learning rate scheduler setup.
    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with the LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    loaded_data = torch.load(checkpoint_path)
    image_translator.load_state_dict(loaded_data["model"], strict=False)
    image_translator.eval()

    stats.new_epoch()
    image_list = []
    for iteration, data in enumerate(test_dataloader):
        print(iteration)
        optimizer.zero_grad()
        views, param_vectors = data
        views = torch.unsqueeze(torch.from_numpy(views), 0)
        param_vectors = torch.unsqueeze(torch.from_numpy(param_vectors), 0)
        views = views.float().cuda()
        param_vectors = param_vectors.float().cuda()

        camera_instance = Camera()
        camera_instance.lookAt(
            param_vectors[0][0],
            math.degrees(param_vectors[0][1]),
            math.degrees(param_vectors[0][2]),
        )

        rasterizer_instance = Rasterizer()
        rasterizer_instance.init_rasterizer(camera_instance.camera)
        fragments = rasterizer_instance.rasterizer(pytorch_mesh)
        pix_to_face = fragments.pix_to_face
        bary_coords = fragments.bary_coords
        pix_features = torch.squeeze(
            interpolate_face_attributes(pix_to_face, bary_coords, torch_face_attrs), 3
        )

        param_matrix = torch.zeros(
            pix_features.size()[0],
            pix_features.size()[1],
            pix_features.size()[2],
            param_vectors.size()[1],
        ).float().cuda()
        param_matrix[:, :, :, :] = param_vectors

        image_features = pix_features  # torch.cat([pix_features, param_matrix], 3)
        predicted_render = image_translator(image_features, torch_texture)
        image_list = [
            views[0].permute(2, 0, 1),
            predicted_render[0].permute(2, 0, 1),
        ]
        if viz is not None:
            visualize_image_outputs(
                validation_images=image_list,
                viz=viz,
                visdom_env=cfg.visualization.visdom_env,
            )
def forward(self, fragments, meshes, **kwargs):
    colors = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, self.face_colors
    )
    # Keep only the closest face per pixel (K = 0).
    return colors[:, :, :, 0]
def interpolate_texture_map(fragments, meshes) -> torch.Tensor:
    """
    Interpolate a 2D texture map using uv vertex texture coordinates for each
    face in the mesh. First interpolate the vertex uvs using barycentric coordinates
    for each pixel in the rasterized output. Then interpolate the texture map
    using the uv coordinate for each pixel.

    Args:
        fragments: The outputs of rasterization. From this we use

            - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices
              of the faces (in the packed representation) which
              overlap each pixel in the image.
            - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying
              the barycentric coordinates of each pixel
              relative to the faces (in the packed
              representation) which overlap the pixel.
        meshes: Meshes representing a batch of meshes. It is expected that
            meshes has a textures attribute which is an instance of the
            Textures class.

    Returns:
        texels: tensor of shape (N, H, W, K, C) giving the interpolated
        texture for each pixel in the rasterized image.
    """
    if not isinstance(meshes.textures, Textures):
        msg = "Expected meshes.textures to be an instance of Textures; got %r"
        raise ValueError(msg % type(meshes.textures))

    faces_uvs = meshes.textures.faces_uvs_packed()
    verts_uvs = meshes.textures.verts_uvs_packed()
    faces_verts_uvs = verts_uvs[faces_uvs]
    texture_maps = meshes.textures.maps_padded()

    # pixel_uvs: (N, H, W, K, 2)
    pixel_uvs = interpolate_face_attributes(
        fragments.pix_to_face, fragments.bary_coords, faces_verts_uvs
    )

    N, H_out, W_out, K = fragments.pix_to_face.shape
    N, H_in, W_in, C = texture_maps.shape  # 3 for RGB

    # pixel_uvs: (N, H, W, K, 2) -> (N, K, H, W, 2) -> (NK, H, W, 2)
    pixel_uvs = pixel_uvs.permute(0, 3, 1, 2, 4).reshape(N * K, H_out, W_out, 2)

    # textures.map:
    #   (N, H, W, C) -> (N, C, H, W) -> (1, N, C, H, W)
    #   -> expand (K, N, C, H, W) -> reshape (N*K, C, H, W)
    texture_maps = (
        texture_maps.permute(0, 3, 1, 2)[None, ...]
        .expand(K, -1, -1, -1, -1)
        .transpose(0, 1)
        .reshape(N * K, C, H_in, W_in)
    )

    # Textures: (N*K, C, H, W), pixel_uvs: (N*K, H, W, 2)
    # Now need to format the pixel uvs and the texture map correctly!
    # From pytorch docs, grid_sample takes `grid` and `input`:
    #   grid specifies the sampling pixel locations normalized by
    #   the input spatial dimensions. It should have most
    #   values in the range of [-1, 1]. Values x = -1, y = -1
    #   is the left-top pixel of input, and values x = 1, y = 1 is the
    #   right-bottom pixel of input.
    pixel_uvs = pixel_uvs * 2.0 - 1.0

    texture_maps = torch.flip(texture_maps, [2])  # flip y axis of the texture map
    if texture_maps.device != pixel_uvs.device:
        texture_maps = texture_maps.to(pixel_uvs.device)
    texels = F.grid_sample(texture_maps, pixel_uvs, align_corners=False)
    texels = texels.reshape(N, K, C, H_out, W_out).permute(0, 3, 4, 1, 2)
    return texels
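# Usage sketch (hedged: `fragments` comes from rasterizing `meshes`, whose
# `.textures` must be the uv-mapped Textures class checked above).
#
# texels = interpolate_texture_map(fragments, meshes)  # (N, H, W, K, C)
# rgb = texels[:, :, :, 0, :]  # keep the closest face per pixel (K = 0)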
def main(cfg: DictConfig):
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    obj_path = cfg.data.obj_path
    texture_path = cfg.data.texture_path
    views_folder = cfg.data.views_folder
    params_file = os.path.join(views_folder, "params.json")

    dataset = CowMultiViews(
        obj_path, views_folder, texture_path, params_file=params_file
    )
    train_dataset, validation_dataset, test_dataset = CowMultiViews.random_split_dataset(
        dataset, train_fraction=0.7, validation_fraction=0.2
    )
    del dataset

    train_dataset.unit_normalize()
    validation_dataset.unit_normalize()

    mesh_verts = train_dataset.get_verts()
    mesh_edges = train_dataset.get_edges()
    mesh_vert_normals = train_dataset.get_vert_normals()
    mesh_texture = train_dataset.get_texture()
    pytorch_mesh = train_dataset.pytorch_mesh.cuda()

    random_face_attrs = train_dataset.get_faces_as_vertex_matrices(
        features_list=["random"], num_random_dims=cfg.training.feature_dim
    )
    coord_face_attrs = train_dataset.get_faces_as_vertex_matrices(
        features_list=["coord"], num_random_dims=cfg.training.feature_dim
    )
    normal_face_attrs = train_dataset.get_faces_as_vertex_matrices(
        features_list=["normal"], num_random_dims=cfg.training.feature_dim
    )

    torch_verts = torch.from_numpy(np.array(mesh_verts)).float().cuda()
    torch_edges = torch.from_numpy(np.array(mesh_edges)).long().cuda()
    torch_normals = torch.from_numpy(np.array(mesh_vert_normals)).float().cuda()
    torch_texture = torch.from_numpy(np.array(mesh_texture)).float().cuda()
    torch_texture = torch.unsqueeze(torch_texture, 0)

    torch_random_face_attrs = torch.tensor(
        np.array(random_face_attrs), requires_grad=True
    ).float().cuda()
    torch_random_face_attrs = torch.nn.Parameter(torch_random_face_attrs)
    torch_coord_face_attrs = torch.tensor(np.array(coord_face_attrs)).float().cuda()
    torch_normal_face_attrs = torch.tensor(np.array(normal_face_attrs)).float().cuda()

    train_dataloader = DataLoader(
        train_dataset, batch_size=cfg.training.batch_size, shuffle=True, num_workers=4
    )
    validation_dataloader = DataLoader(
        validation_dataset, batch_size=cfg.training.batch_size, shuffle=True, num_workers=4
    )

    image_translator = ImageTranslator(
        input_dim=cfg.training.feature_dim + 9,
        output_dim=3,
        image_size=tuple(cfg.data.image_size),
    ).cuda()
    mse_loss = torch.nn.MSELoss()

    # Initialize the optimizer; the learnable per-face features are optimized
    # jointly with the translator network.
    optimizer = torch.optim.Adam(
        list(image_translator.parameters()) + [torch_random_face_attrs],
        lr=cfg.optimizer.lr,
    )

    stats = None
    start_epoch = 0
    checkpoint_path = os.path.join(hydra.utils.get_original_cwd(), cfg.checkpoint_path)

    # Init the stats object.
    if stats is None:
        stats = Stats(["mse_loss", "sec/it"])

    # Learning rate scheduler setup.
    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with the LambdaLR PyTorch scheduler.
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1, verbose=False
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    for epoch in range(cfg.optimizer.max_epochs):
        image_translator.train()
        stats.new_epoch()
        for iteration, data in enumerate(train_dataloader):
            optimizer.zero_grad()
            views, param_vectors = data
            views = views.float().cuda()
            param_vectors = param_vectors.float().cuda()

            camera_instance = Camera()
            camera_instance.lookAt(
                param_vectors[0][0],
                math.degrees(param_vectors[0][1]),
                math.degrees(param_vectors[0][2]),
            )
            camera_location = param_vectors[0, 3:6]
            light_location = param_vectors[0, 6:9]

            torch_camera_face_attrs = torch_coord_face_attrs - camera_location
            torch_light_face_attrs = torch_coord_face_attrs - light_location
            torch_face_attrs = torch.cat(
                [
                    torch_camera_face_attrs,
                    torch_normal_face_attrs,
                    torch_light_face_attrs,
                    torch_random_face_attrs,
                ],
                2,
            )

            rasterizer_instance = Rasterizer()
            rasterizer_instance.init_rasterizer(camera_instance.camera)
            fragments = rasterizer_instance.rasterizer(pytorch_mesh)
            pix_to_face = fragments.pix_to_face
            bary_coords = fragments.bary_coords
            pix_features = torch.squeeze(
                interpolate_face_attributes(pix_to_face, bary_coords, torch_face_attrs), 3
            )

            predicted_render = image_translator(pix_features, torch_texture)
            loss = 1000 * mse_loss(predicted_render, views)
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update({"mse_loss": float(loss)}, stat_set="train")
            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

        # Adjust the learning rate.
        # lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0:  # and epoch > 0:
            # Sample a validation camera/image.
            val_batch = next(iter(validation_dataloader))
            views, param_vectors = val_batch
            views = views.float().cuda()
            param_vectors = param_vectors.float().cuda()

            # Activate eval mode of the model (allows to do a full rendering pass).
            image_translator.eval()
            with torch.no_grad():
                camera_instance = Camera()
                camera_instance.lookAt(
                    param_vectors[0][0],
                    math.degrees(param_vectors[0][1]),
                    math.degrees(param_vectors[0][2]),
                )
                camera_location = param_vectors[0, 3:6]
                light_location = param_vectors[0, 6:9]
                torch_camera_face_attrs = torch_coord_face_attrs - camera_location
                torch_light_face_attrs = torch_coord_face_attrs - light_location
                torch_face_attrs = torch.cat(
                    [
                        torch_camera_face_attrs,
                        torch_normal_face_attrs,
                        torch_light_face_attrs,
                        torch_random_face_attrs,
                    ],
                    2,
                )

                rasterizer_instance = Rasterizer()
                rasterizer_instance.init_rasterizer(camera_instance.camera)
                fragments = rasterizer_instance.rasterizer(pytorch_mesh)
                pix_to_face = fragments.pix_to_face
                bary_coords = fragments.bary_coords
                pix_features = torch.squeeze(
                    interpolate_face_attributes(pix_to_face, bary_coords, torch_face_attrs), 3
                )
                # pix_features = pix_features.permute(0, 3, 1, 2)
                predicted_render = image_translator(pix_features, torch_texture)
                loss = 1000 * mse_loss(predicted_render, views)

                # Update stats with the validation metrics.
                stats.update({"mse_loss": loss}, stat_set="val")
                stats.print(stat_set="val")

                if viz is not None:
                    # Plot the loss curves into visdom.
                    stats.plot_stats(
                        viz=viz,
                        visdom_env=cfg.visualization.visdom_env,
                        plot_file=None,
                    )
                    # Visualize the intermediate results.
                    render_max = torch.max(predicted_render)
                    visualize_image_outputs(
                        validation_images=[
                            views[0].permute(2, 0, 1),
                            predicted_render[0].permute(2, 0, 1),
                        ],
                        viz=viz,
                        visdom_env=cfg.visualization.visdom_env,
                    )

            # Set the model back to train mode.
            image_translator.train()

        # Checkpoint.
        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and len(cfg.checkpoint_path) > 0
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": image_translator.state_dict(),
                "features": torch_face_attrs,
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            torch.save(data_to_store, checkpoint_path)