def calc_depth_loss(self, mask_depth, depth_img, pixels, camera_mat,
                    world_mat, scale_mat, p_world_hat, reduction_method,
                    loss=None, eval_mode=False):
    ''' Calculates the depth loss.

    Args:
        mask_depth (tensor): mask for depth loss
        depth_img (tensor): depth image
        pixels (tensor): sampled pixels in range [-1, 1]
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        p_world_hat (tensor): predicted world points
        reduction_method (string): how to reduce the loss tensor
        loss (dict): loss dictionary
        eval_mode (bool): whether to use eval mode
    '''
    # Avoid a mutable default argument, which would be shared across calls
    if loss is None:
        loss = {}
    if self.lambda_depth != 0 and mask_depth.sum() > 0:
        batch_size, n_pts, _ = p_world_hat.shape
        # Placeholder; overwritten below whenever eval_mode is True
        loss_depth_val = torch.tensor(10)
        # For depth values, we have to check again if all values are valid
        # as we potentially train with sparse depth maps
        depth_gt, mask_gt_depth = get_tensor_values(
            depth_img, pixels, squeeze_channel_dim=True, with_mask=True)
        mask_depth &= mask_gt_depth

        if self.depth_loss_on_world_points:
            # Applying an L2 loss on world points is equivalent to
            # applying an L1 loss on the depth values with a per-pixel
            # scaling (see Sup. Mat.)
            p_world = transform_to_world(
                pixels, depth_gt.unsqueeze(-1), camera_mat, world_mat,
                scale_mat)
            loss_depth = losses.l2_loss(
                p_world_hat[mask_depth], p_world[mask_depth],
                reduction_method) * self.lambda_depth / batch_size
            if eval_mode:
                loss_depth_val = losses.l2_loss(
                    p_world_hat[mask_depth], p_world[mask_depth],
                    'mean') * self.lambda_depth
        else:
            d_pred = transform_to_camera_space(
                p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
            loss_depth = losses.l1_loss(
                d_pred[mask_depth], depth_gt[mask_depth],
                reduction_method, feat_dim=False) * \
                self.lambda_depth / batch_size
            if eval_mode:
                loss_depth_val = losses.l1_loss(
                    d_pred[mask_depth], depth_gt[mask_depth],
                    'mean', feat_dim=False) * self.lambda_depth

        loss['loss'] += loss_depth
        loss['loss_depth'] = loss_depth
        if eval_mode:
            loss['loss_depth_eval'] = loss_depth_val
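# A minimal, hedged sketch (not part of the module above) illustrating the
# equivalence claimed in the comment: for points along a single pixel ray with
# direction `ray`, the L2 distance between the backprojected world points is
# |d_hat - d_gt| * ||ray||, i.e. an L1 depth error with a per-pixel scale
# factor. All names here are illustrative.
import torch

ray = torch.tensor([0.3, -0.2, 0.9])            # unnormalized viewing ray
d_gt, d_hat = torch.tensor(2.0), torch.tensor(2.5)
p_gt, p_hat = d_gt * ray, d_hat * ray           # world points along the ray
assert torch.isclose((p_hat - p_gt).norm(),
                     (d_hat - d_gt).abs() * ray.norm())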
def calc_photoconsistency_loss(self, mask_rgb, rgb_pred, img, pixels,
                               reduction_method, loss, patch_size,
                               eval_mode=False):
    ''' Calculates the photo-consistency loss.

    Args:
        mask_rgb (tensor): mask for photo-consistency loss
        rgb_pred (tensor): predicted rgb color values
        img (tensor): GT image
        pixels (tensor): sampled pixels in range [-1, 1]
        reduction_method (string): how to reduce the loss tensor
        loss (dict): loss dictionary
        patch_size (int): size of sampled patch
        eval_mode (bool): whether to use eval mode
    '''
    if self.lambda_rgb != 0 and mask_rgb.sum() > 0:
        batch_size, n_pts, _ = rgb_pred.shape
        # Placeholder; overwritten below whenever eval_mode is True
        loss_rgb_eval = torch.tensor(3)
        # Get GT RGB values
        rgb_gt = get_tensor_values(img, pixels)

        # 3.1) Calculate RGB loss
        loss_rgb = losses.l1_loss(
            rgb_pred[mask_rgb], rgb_gt[mask_rgb],
            reduction_method) * self.lambda_rgb / batch_size
        loss['loss'] += loss_rgb
        loss['loss_rgb'] = loss_rgb
        if eval_mode:
            loss_rgb_eval = losses.l1_loss(
                rgb_pred[mask_rgb], rgb_gt[mask_rgb], 'mean') * \
                self.lambda_rgb

        # 3.2) Image gradient loss
        if self.lambda_image_gradients != 0:
            assert patch_size > 1
            loss_grad = losses.image_gradient_loss(
                rgb_pred, rgb_gt, mask_rgb, patch_size,
                reduction_method) * \
                self.lambda_image_gradients / batch_size
            loss['loss'] += loss_grad
            loss['loss_image_gradient'] = loss_grad

        if eval_mode:
            loss['loss_rgb_eval'] = loss_rgb_eval
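# A hedged sketch of what the masked RGB term computes in isolation, assuming
# losses.l1_loss with 'sum' reduction is an elementwise L1 over the selected
# pixels. rgb_pred/rgb_gt are (B, N, 3); indexing with the (B, N) boolean mask
# flattens the valid pixels to (M, 3) before the difference is taken.
import torch

rgb_pred, rgb_gt = torch.rand(2, 4, 3), torch.rand(2, 4, 3)
mask_rgb = torch.tensor([[1, 0, 1, 1],
                         [0, 1, 0, 0]], dtype=torch.bool)
loss_rgb = (rgb_pred[mask_rgb] - rgb_gt[mask_rgb]).abs().sum()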
v_out, rgb_out = [], []
for i, imgp in enumerate(img_list):
    # load files
    img = np.array(Image.open(imgp).convert("RGB")).astype(np.float32) / 255
    h, w, c = img.shape
    depth = np.array(imageio.imread(depth_list[i]))
    depth = depth.reshape(depth.shape[0], depth.shape[1], -1)[..., 0]
    hd, wd = depth.shape
    assert h == hd and w == wd

    # sample pixel locations and gather RGB / depth values there
    p = sample_patch_points(1, n_sample_points, patch_size=1,
                            image_resolution=(h, w), continuous=False)
    img = torch.from_numpy(img).permute(2, 0, 1).unsqueeze(0)
    depth = torch.from_numpy(depth).unsqueeze(0).unsqueeze(1)
    rgb = get_tensor_values(img, p)
    d = get_tensor_values(depth, p)

    # keep only pixels with a valid (finite) depth
    mask = (d != np.inf).squeeze(-1)
    d = d[mask].unsqueeze(0)
    rgb = rgb[mask]
    p = p[mask].unsqueeze(0)

    # transform to world
    cm = cam.get('camera_mat_%d' % i).astype(np.float32).reshape(1, 4, 4)
    wm = cam.get('world_mat_%d' % i).astype(np.float32).reshape(1, 4, 4)
    sm = cam.get('scale_mat_%d' % i, np.eye(4)).astype(
        np.float32).reshape(1, 4, 4)
    p_world = transform_to_world(p, d, cm, wm, sm)[0]
    v_out.append(p_world.numpy())
    rgb_out.append(rgb.numpy())

# fuse all views and export a colored point cloud
v = np.concatenate(v_out, axis=0)
c = (np.concatenate(rgb_out, axis=0) * 255).astype(np.uint8)
trimesh.Trimesh(vertices=v, vertex_colors=c,
                process=False).export('out/pcd.ply')
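# A hedged sketch of the context the snippet above assumes: `img_list` and
# `depth_list` are per-view image/depth file paths and `cam` is an npz archive
# with 'camera_mat_%d', 'world_mat_%d' and 'scale_mat_%d' entries (the key
# names used above). The directory layout below is hypothetical.
import glob
import numpy as np

img_list = sorted(glob.glob('data/scene0/image/*.png'))
depth_list = sorted(glob.glob('data/scene0/depth/*.exr'))
cam = np.load('data/scene0/cameras.npz')
n_sample_points = 1000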
def compute_loss(self, data, eval_mode=False, it=None):
    ''' Computes the loss.

    Args:
        data (dict): data dictionary
        eval_mode (bool): whether to use eval mode
        it (int): training iteration
    '''
    # Initialize loss dictionary and other values
    loss = {}
    n_points = self.n_eval_points if eval_mode else self.n_training_points

    # Process data dictionary
    (img, mask_img, depth_img, world_mat, camera_mat, scale_mat,
     inputs, sparse_depth) = self.process_data_dict(data)

    # Shortcuts
    device = self.device
    patch_size = self.patch_size
    reduction_method = self.reduction_method
    batch_size, _, h, w = img.shape

    # Assertions
    assert ((h, w) == mask_img.shape[2:4] and patch_size > 0 and
            n_points > 0)

    # Sample points on image plane ("pixels")
    if n_points >= h * w:
        p = arange_pixels((h, w), batch_size)[1].to(device)
    else:
        p = sample_patch_points(
            batch_size, n_points,
            patch_size=patch_size,
            image_resolution=(h, w),
            continuous=self.sample_continuous,
        ).to(device)

    # Apply losses
    # 1.) Get object mask values and define masks for losses
    mask_gt = get_tensor_values(
        mask_img, p, squeeze_channel_dim=True).bool()

    # Calculate 3D points which need to be evaluated for the occupancy
    # and freespace loss
    p_freespace = get_freespace_loss_points(
        p, camera_mat, world_mat, scale_mat,
        self.use_cube_intersection, self.depth_range)

    depth_input = depth_img if (
        self.lambda_depth != 0 or self.depth_from_visual_hull) else None
    p_occupancy = get_occupancy_loss_points(
        p, camera_mat, world_mat, scale_mat, depth_input,
        self.use_cube_intersection, self.occupancy_random_normal,
        self.depth_range)

    # 2.) Initialize loss
    loss['loss'] = 0

    # 3.) Make forward pass through the network and obtain predictions
    # with masks
    (p_world_hat, rgb_pred, logits_occupancy, logits_freespace,
     mask_pred, p_world_hat_sparse, mask_pred_sparse, normals) = \
        self.model(p, p_occupancy, p_freespace, inputs, camera_mat,
                   world_mat, scale_mat, it, sparse_depth,
                   self.lambda_normal != 0)

    # 4.) Calculate losses
    # 4.1) Photo-consistency loss
    mask_rgb = mask_pred & mask_gt
    self.calc_photoconsistency_loss(mask_rgb, rgb_pred, img, p,
                                    reduction_method, loss, patch_size,
                                    eval_mode)

    # 4.2) Depth loss
    mask_depth = mask_pred & mask_gt
    self.calc_depth_loss(mask_depth, depth_img, p, camera_mat,
                         world_mat, scale_mat, p_world_hat,
                         reduction_method, loss, eval_mode)

    # 4.3) Normal loss
    self.calc_normal_loss(normals, batch_size, loss, eval_mode)

    # 4.4) Sparse depth loss
    self.calc_sparse_depth_loss(sparse_depth, p_world_hat_sparse,
                                mask_pred_sparse, reduction_method,
                                loss, eval_mode)

    # 4.5) Freespace loss
    mask_freespace = (mask_gt == 0) if self.always_freespace else \
        ((mask_gt == 0) & mask_pred)
    self.calc_freespace_loss(logits_freespace, mask_freespace,
                             reduction_method, loss)

    # 4.6) Occupancy loss
    mask_occupancy = (mask_pred == 0) & mask_gt
    self.calc_occupancy_loss(logits_occupancy, mask_occupancy,
                             reduction_method, loss)

    # Save mean mask intersection for tensorboard
    if eval_mode:
        self.calc_mask_intersection(mask_gt, mask_pred, loss)

    return loss if eval_mode else loss['loss']
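# A hedged sketch (not from the original trainer) of how compute_loss is
# typically consumed in a training step; `self.optimizer` is assumed to be a
# standard torch.optim optimizer attached to self.model's parameters.
def train_step(self, data, it=None):
    ''' Performs a single gradient step on one batch. '''
    self.model.train()
    self.optimizer.zero_grad()
    loss = self.compute_loss(data, eval_mode=False, it=it)  # scalar tensor
    loss.backward()
    self.optimizer.step()
    return loss.item()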
def test_dataset(self):
    # 1. rerender input point clouds / meshes using the saved camera_mat
    #    and compare the mask image with the saved mask image
    # 2. backproject masked points to space with the dense depth map,
    #    fuse all views and save
    batch_size = 1
    device = torch.device('cuda:0')
    data_dir = 'data/synthetic/cube_mesh'
    output_dir = os.path.join('tests', 'outputs', 'test_data')
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)

    # dataset
    dataset = MVRDataset(data_dir=data_dir, load_dense_depth=True,
                         mode="train")
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batch_size,
                                              num_workers=0,
                                              shuffle=False)
    meshes = load_objs_as_meshes(
        [os.path.join(data_dir, 'mesh.obj')]).to(device)
    cams = dataset.get_cameras().to(device)
    image_size = imageio.imread(dataset.image_files[0]).shape[0]

    # initialize the rasterizer; we check mask pngs only, so there is no
    # need to create lights, shaders etc.
    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=5,
        # bin_size controls whether naive or coarse-to-fine rasterization
        # is used; None lets PyTorch3D choose
        bin_size=None,
        # max_faces_per_bin is only relevant for coarse rasterization
        max_faces_per_bin=None,
    )
    rasterizer = MeshRasterizer(cameras=None,
                                raster_settings=raster_settings)

    # render with the loaded camera positions and the training
    # transformation functions
    pixel_world_all = []
    for idx, data in enumerate(data_loader):
        # get data
        img = data.get('img.rgb').to(device)
        assert (img.min() >= 0 and img.max() <= 1
                ), "Image must be a floating number between 0 and 1."
        mask_gt = data.get('img.mask').to(device).permute(0, 2, 3, 1)

        camera_mat = data['camera_mat'].to(device)

        cams.R, cams.T = decompose_to_R_and_t(camera_mat)
        cams._N = cams.R.shape[0]
        cams.to(device)
        self.assertTrue(
            torch.equal(cams.get_world_to_view_transform().get_matrix(),
                        camera_mat))

        # transform to view space and rerender with a non-rotated camera
        verts_padded = transform_to_camera_space(meshes.verts_padded(),
                                                 cams)
        meshes_in_view = meshes.offset_verts(
            -meshes.verts_packed() + padded_to_packed(
                verts_padded, meshes.mesh_to_verts_packed_first_idx(),
                meshes.verts_packed().shape[0]))
        fragments = rasterizer(meshes_in_view,
                               cameras=dataset.get_cameras().to(device))

        # compare masks
        mask = fragments.pix_to_face[..., :1] >= 0
        imageio.imwrite(os.path.join(output_dir, "mask_%06d.png" % idx),
                        mask[0, ...].cpu().to(dtype=torch.uint8) * 255)
        # allow 5 pixels difference
        self.assertTrue(torch.sum(mask_gt != mask) < 5)

        # check dense maps:
        # backproject points to the world (pixel range (-1, 1))
        pixels = arange_pixels((image_size, image_size),
                               batch_size)[1].to(device)
        depth_img = data.get('img.depth').to(device)

        # get the depth and mask at the sampled pixel positions
        depth_gt = get_tensor_values(depth_img, pixels,
                                     squeeze_channel_dim=True)
        mask_gt = get_tensor_values(mask.permute(0, 3, 1, 2).float(),
                                    pixels,
                                    squeeze_channel_dim=True).bool()

        # get pixels and depth inside the masked area
        pixels_packed = pixels[mask_gt]
        depth_gt_packed = depth_gt[mask_gt]
        first_idx = torch.zeros((pixels.shape[0],), device=device,
                                dtype=torch.long)
        num_pts_in_mask = mask_gt.sum(dim=1)
        first_idx[1:] = num_pts_in_mask.cumsum(dim=0)[:-1]
        pixels_padded = packed_to_padded(pixels_packed, first_idx,
                                         num_pts_in_mask.max().item())
        depth_gt_padded = packed_to_padded(depth_gt_packed, first_idx,
                                           num_pts_in_mask.max().item())

        # backproject to world coordinates; the result contains nan and
        # infinite values where depth_gt_padded is 0.0
        pixel_world_padded = transform_to_world(
            pixels_padded, depth_gt_padded[..., None], cams)

        # transform back to a list containing no padded values
        split_size = num_pts_in_mask[..., None].repeat(1, 2)
        split_size[:, 1] = 3
        pixel_world_list = padded_to_list(pixel_world_padded, split_size)
        pixel_world_all.extend(pixel_world_list)

        idx += 1
        if idx >= 10:
            break

    pixel_world_all = torch.cat(pixel_world_all, dim=0)
    mesh = trimesh.Trimesh(vertices=pixel_world_all.cpu(), faces=None,
                           process=False)
    mesh.export(os.path.join(output_dir, 'pixel_to_world.ply'))
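# A minimal, hedged sketch of the packed -> padded conversion the test relies
# on (pytorch3d.ops.packed_to_padded): a packed tensor concatenates all valid
# points across the batch, first_idx marks where each batch element starts,
# and the padded result is (B, max_n, 3) with zeros past each element's count.
import torch
from pytorch3d.ops import packed_to_padded

pts_packed = torch.arange(15, dtype=torch.float32).reshape(5, 3)  # 5 points
first_idx = torch.tensor([0, 2])  # element 0 owns points 0-1, element 1: 2-4
num_pts = torch.tensor([2, 3])
pts_padded = packed_to_padded(pts_packed, first_idx, int(num_pts.max()))
assert pts_padded.shape == (2, 3, 3) and pts_padded[0, 2].eq(0).all()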