def compute_loss(self, data, eval_mode=False, it=None):
    ''' Compute the loss.

    Args:
        data (dict): data dictionary
        eval_mode (bool): whether to use eval mode
        it (int): training iteration

    Returns:
        The full loss dictionary when ``eval_mode`` is True, otherwise only
        ``loss['loss']`` (initialized to 0 here and passed to the calc_*
        helpers, which presumably accumulate into it).

    Raises:
        ValueError: if the mask resolution differs from the image resolution,
            or if ``patch_size`` or the number of sampled points is not
            positive.
    '''
    # Initialize loss dictionary and other values
    loss = {}
    n_points = self.n_eval_points if eval_mode else self.n_training_points

    # Process data dictionary
    (img, mask_img, depth_img, world_mat, camera_mat, scale_mat, inputs,
     sparse_depth) = self.process_data_dict(data)

    # Shortcuts
    device = self.device
    patch_size = self.patch_size
    reduction_method = self.reduction_method
    batch_size, _, h, w = img.shape

    # Validate with explicit exceptions: unlike `assert`, these checks are
    # not stripped when Python runs with -O, and each says what went wrong.
    if tuple(mask_img.shape[2:4]) != (h, w):
        raise ValueError(
            'mask resolution %s does not match image resolution %s'
            % (tuple(mask_img.shape[2:4]), (h, w)))
    if patch_size <= 0:
        raise ValueError(
            'patch_size must be positive, got %r' % (patch_size,))
    if n_points <= 0:
        raise ValueError(
            'number of points must be positive, got %r' % (n_points,))

    # Sample points on image plane ("pixels")
    if n_points >= h * w:
        p = arange_pixels((h, w), batch_size)[1].to(device)
    else:
        p = sample_patch_points(
            batch_size, n_points,
            patch_size=patch_size,
            image_resolution=(h, w),
            continuous=self.sample_continuous,
        ).to(device)

    # Apply losses
    # 1.) Get Object Mask values and define masks for losses
    mask_gt = get_tensor_values(
        mask_img, p, squeeze_channel_dim=True).bool()

    # Calculate 3D points which need to be evaluated for the occupancy and
    # freespace loss
    p_freespace = get_freespace_loss_points(
        p, camera_mat, world_mat, scale_mat, self.use_cube_intersection,
        self.depth_range)

    # Pass the depth image only when the depth loss is active or
    # depth_from_visual_hull is set; otherwise None is passed.
    depth_input = depth_img if (
        self.lambda_depth != 0 or self.depth_from_visual_hull) else None
    p_occupancy = get_occupancy_loss_points(
        p, camera_mat, world_mat, scale_mat, depth_input,
        self.use_cube_intersection, self.occupancy_random_normal,
        self.depth_range)

    # 2.) Initialize loss
    loss['loss'] = 0

    # 3.) Make forward pass through the network and obtain predictions
    # with masks
    (p_world_hat, rgb_pred, logits_occupancy, logits_freespace, mask_pred,
     p_world_hat_sparse, mask_pred_sparse, normals) = self.model(
        p, p_occupancy, p_freespace, inputs, camera_mat, world_mat,
        scale_mat, it, sparse_depth, self.lambda_normal != 0)

    # 4.) Calculate Loss
    # 4.1) Photo Consistency Loss
    mask_rgb = mask_pred & mask_gt
    self.calc_photoconsistency_loss(mask_rgb, rgb_pred, img, p,
                                    reduction_method, loss, patch_size,
                                    eval_mode)

    # 4.2) Calculate Depth Loss
    mask_depth = mask_pred & mask_gt
    self.calc_depth_loss(mask_depth, depth_img, p, camera_mat, world_mat,
                         scale_mat, p_world_hat, reduction_method, loss,
                         eval_mode)

    # 4.3) Calculate normal loss
    self.calc_normal_loss(normals, batch_size, loss, eval_mode)

    # 4.4) Sparse Depth Loss
    self.calc_sparse_depth_loss(sparse_depth, p_world_hat_sparse,
                                mask_pred_sparse, reduction_method, loss,
                                eval_mode)

    # 4.5) Freespace loss: pixels outside the GT mask; unless
    # always_freespace is set, additionally restricted to where mask_pred
    # is set.
    if self.always_freespace:
        mask_freespace = (mask_gt == 0)
    else:
        mask_freespace = (mask_gt == 0) & (mask_pred)
    self.calc_freespace_loss(logits_freespace, mask_freespace,
                             reduction_method, loss)

    # 4.6) Occupancy Loss: pixels inside the GT mask where mask_pred is
    # not set.
    mask_occupancy = (mask_pred == 0) & mask_gt
    self.calc_occupancy_loss(logits_occupancy, mask_occupancy,
                             reduction_method, loss)

    # Save mean mask intersection for tensorboard
    if eval_mode:
        self.calc_mask_intersection(mask_gt, mask_pred, loss)

    return loss if eval_mode else loss['loss']
def render_img(self, camera_mat, world_mat, inputs, scale_mat=None, c=None,
               stats_dict=None, resolution=(128, 128)):
    ''' Renders an image for provided camera information.

    Args:
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        inputs: not used by this method; kept for interface compatibility
        scale_mat (tensor): scale matrix
        c (tensor): latent conditioned code c
        stats_dict (dict): statistics dictionary. The timing entries
            'time_prepare_points', 'time_eval_depth' and 'time_eval_color'
            are written into it in place. A fresh dict is created when it is
            omitted, so timings do not leak between calls through a shared
            default.
        resolution (tuple): output image resolution (h, w)

    Returns:
        PIL.Image: an 'RGB' image when ``self.colors == 'rgb'``, or an 'L'
        (grayscale) depth visualization when ``self.colors == 'depth'``.
        Pixels without a predicted surface point stay white (255).

    Raises:
        ValueError: if ``self.colors`` is neither 'rgb' nor 'depth'.
    '''
    if stats_dict is None:
        stats_dict = {}
    if self.colors not in ('rgb', 'depth'):
        # Without this check the method would reach `return img_out` with
        # img_out unassigned (UnboundLocalError); fail early and clearly.
        raise ValueError(
            "render_img supports colors 'rgb' or 'depth', got %r"
            % (self.colors,))

    device = self.device
    h, w = resolution

    t0 = time.time()
    p_loc, pixels = arange_pixels(resolution=(h, w))
    pixels = pixels.to(device)
    stats_dict['time_prepare_points'] = time.time() - t0

    with torch.no_grad():
        # Get predicted world points
        t0 = time.time()
        p_world_hat, mask_pred, _ = self.model.pixels_to_world(
            pixels, camera_mat, world_mat, scale_mat, c,
            sampling_accuracy=self.sampling_accuracy)
        stats_dict['time_eval_depth'] = time.time() - t0

        t0 = time.time()
        # Keep only the pixel locations that hit a surface. A CPU mask can
        # index p_loc whatever its device, and converting to numpy lets it
        # index the numpy output image directly.
        p_loc = p_loc[mask_pred.cpu()].cpu().numpy()
        has_hits = bool(mask_pred.sum() > 0)

        if self.colors == 'rgb':
            img_out = (255 * np.ones((h, w, 3))).astype(np.uint8)
            if has_hits:
                rgb_hat = self.model.decode_color(p_world_hat, c=c)
                rgb_hat = rgb_hat[mask_pred].cpu().numpy()
                # Clip before the uint8 cast so out-of-range values
                # saturate instead of wrapping around.
                rgb_hat = (np.clip(rgb_hat, 0., 1.) * 255).astype(np.uint8)
                img_out[p_loc[:, 1], p_loc[:, 0]] = rgb_hat
            img_out = Image.fromarray(img_out).convert('RGB')
        else:  # self.colors == 'depth'
            img_out = (255 * np.ones((h, w))).astype(np.uint8)
            if has_hits:
                p_world_hat = p_world_hat[mask_pred].unsqueeze(0)
                d_values = transform_to_camera_space(
                    p_world_hat, camera_mat, world_mat,
                    scale_mat).squeeze(0)[:, -1].cpu().numpy()
                # Normalize finite depths into [0.5, 0.95]. Non-finite
                # depths (+/-inf, NaN) are skipped and stay white instead of
                # being cast to uint8.
                finite = np.isfinite(d_values)
                if finite.any():
                    d_values = d_values[finite]
                    p_valid = p_loc[finite]
                    d_min, d_max = d_values.min(), d_values.max()
                    if d_max > d_min:
                        d_values = 0.5 + 0.45 * (d_values - d_min) / (
                            d_max - d_min)
                    else:
                        # All hits at the same depth: avoid 0/0 -> NaN.
                        d_values = np.full_like(d_values, 0.5)
                    img_out[p_valid[:, 1], p_valid[:, 0]] = \
                        (d_values * 255).astype(np.uint8)
            img_out = Image.fromarray(img_out).convert("L")
        stats_dict['time_eval_color'] = time.time() - t0
    return img_out
def test_dataset(self):
    '''Round-trip check of the synthetic cube dataset.

    1. Re-render the reference mesh using each saved camera_mat and compare
       the rasterized silhouette against the saved mask image.
    2. Back-project the masked pixels with the dense depth map, fuse the
       processed views and export the fused points to a PLY file.
    '''
    batch_size = 1
    max_batches = 10  # stop after this many batches to keep the test short
    # Prefer the GPU, but fall back to CPU so the test can run without CUDA.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    data_dir = 'data/synthetic/cube_mesh'
    output_dir = os.path.join('tests', 'outputs', 'test_data')
    os.makedirs(output_dir, exist_ok=True)

    # dataset
    dataset = MVRDataset(data_dir=data_dir, load_dense_depth=True,
                         mode="train")
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, num_workers=0, shuffle=False)
    meshes = load_objs_as_meshes(
        [os.path.join(data_dir, 'mesh.obj')]).to(device)
    cams = dataset.get_cameras().to(device)
    image_size = imageio.imread(dataset.image_files[0]).shape[0]

    # Only mask PNGs are checked, so no lights or shaders are needed.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=5,
        # controls whether naive or coarse-to-fine rasterization is used
        bin_size=None,
        # only relevant for coarse rasterization
        max_faces_per_bin=None,
    )
    rasterizer = MeshRasterizer(cameras=None, raster_settings=raster_settings)

    # Render with the loaded camera positions and the training
    # transformation functions.
    pixel_world_all = []
    for idx, data in enumerate(data_loader):
        img = data.get('img.rgb').to(device)
        self.assertTrue(
            img.min() >= 0 and img.max() <= 1,
            "Image must be a floating number between 0 and 1.")
        mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)
        camera_mat = data['camera_mat'].to(device)
        cams.R, cams.T = decompose_to_R_and_t(camera_mat)
        cams._N = cams.R.shape[0]
        cams.to(device)
        self.assertTrue(
            torch.equal(cams.get_world_to_view_transform().get_matrix(),
                        camera_mat))

        # Transform to view space and re-render with a non-rotated camera.
        verts_padded = transform_to_camera_space(meshes.verts_padded(), cams)
        meshes_in_view = meshes.offset_verts(
            -meshes.verts_packed() + padded_to_packed(
                verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                meshes.verts_packed().shape[0]))
        fragments = rasterizer(meshes_in_view,
                               cameras=dataset.get_cameras().to(device))

        # Compare the rendered silhouette with the saved mask.
        mask = fragments.pix_to_face[..., :1] >= 0
        imageio.imwrite(
            os.path.join(output_dir, "mask_%06d.png" % idx),
            (mask[0, ..., 0].to(dtype=torch.uint8) * 255).cpu().numpy())
        # Tolerate fewer than 5 mismatching pixels.
        self.assertTrue(torch.sum(mask_gt != mask) < 5)

        # Check the dense maps: back-project points to the world.
        # Pixel range is (-1, 1).
        pixels = arange_pixels((image_size, image_size),
                               batch_size)[1].to(device)
        depth_img = data.get('img.depth').to(device)
        # Depth and rendered mask at the sampled pixel positions.
        depth_gt = get_tensor_values(depth_img, pixels,
                                     squeeze_channel_dim=True)
        mask_sampled = get_tensor_values(
            mask.permute(0, 3, 1, 2).float(), pixels,
            squeeze_channel_dim=True).bool()

        # Pixels and depths inside the masked area, packed then re-padded
        # per batch element.
        pixels_packed = pixels[mask_sampled]
        depth_gt_packed = depth_gt[mask_sampled]
        num_pts_in_mask = mask_sampled.sum(dim=1)
        first_idx = torch.zeros((pixels.shape[0], ), device=device,
                                dtype=torch.long)
        first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
        max_pts = num_pts_in_mask.max().item()
        pixels_padded = packed_to_padded(pixels_packed, first_idx, max_pts)
        depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                           max_pts)

        # Back-project to world coordinates. The result contains nan and
        # infinite values because depth_gt_padded contains 0.0 entries.
        pixel_world_padded = transform_to_world(
            pixels_padded, depth_gt_padded[..., None], cams)

        # Transform back to a list that contains no padded values
        # (num_pts_in_mask rows x 3 coordinates per batch element).
        split_size = num_pts_in_mask[..., None].repeat(1, 2)
        split_size[:, 1] = 3
        pixel_world_all.extend(padded_to_list(pixel_world_padded, split_size))

        if idx + 1 >= max_batches:
            break

    self.assertTrue(pixel_world_all, 'no batches were processed')
    pixel_world_all = torch.cat(pixel_world_all, dim=0)
    mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu().numpy(),
                           faces=None, process=False)
    mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))