def forward(self, batch, return_logs=False, progress=0.0, **kwargs):
    """
    Processes a batch.

    Parameters
    ----------
    batch : dict
        Input batch
    return_logs : bool
        True if logs are stored
    progress : float
        Training progress percentage

    Returns
    -------
    output : dict
        Dictionary containing a "loss" scalar and different metrics and predictions
        for logging and downstream usage.
    """
    if not self.training:
        # If not training, no need for self-supervised loss
        return SfmModel.forward(self, batch, return_logs=return_logs, **kwargs)
    else:
        if self.supervised_loss_weight == 1.:
            # If no self-supervision, no need to calculate the self-supervised loss
            self_sup_output = SfmModel.forward(self, batch, return_logs=return_logs, **kwargs)
            loss = torch.tensor([0.]).type_as(batch['rgb'])
        else:
            # Otherwise, calculate and weight self-supervised loss
            self_sup_output = SelfSupModel.forward(
                self, batch, return_logs=return_logs, progress=progress, **kwargs)
            loss = (1.0 - self.supervised_loss_weight) * self_sup_output['loss']

        # Calculate and weight supervised loss
        sup_output = self.supervised_loss(
            self_sup_output['inv_depths'], depth2inv(batch['depth']),
            return_logs=return_logs, progress=progress)
        loss += self.supervised_loss_weight * sup_output['loss']

        # Extra supervised term on the RGB-D branch, if present
        if 'inv_depths_rgbd' in self_sup_output:
            sup_output2 = self.supervised_loss(
                self_sup_output['inv_depths_rgbd'], depth2inv(batch['depth']),
                return_logs=return_logs, progress=progress)
            loss += self.weight_rgbd * self.supervised_loss_weight * sup_output2['loss']
            if 'depth_loss' in self_sup_output:
                loss += self_sup_output['depth_loss']

        # Merge and return outputs
        return {
            'loss': loss,
            **merge_outputs(self_sup_output, sup_output),
        }
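
# Illustrative sketch (not part of the model above): the semi-supervised objective
# blends the self-supervised and supervised terms with a single weight, exactly as
# in the forward pass above. The tensor values below are hypothetical placeholders.
import torch

def blend_losses(self_sup_loss, sup_loss, supervised_loss_weight):
    """Convex combination of self-supervised and supervised loss terms."""
    return (1.0 - supervised_loss_weight) * self_sup_loss + supervised_loss_weight * sup_loss

# Example: with weight 0.9, the supervised term dominates the total loss.
example_loss = blend_losses(torch.tensor(0.12), torch.tensor(0.05), supervised_loss_weight=0.9)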
def forward(self, batch, return_logs=False, progress=0.0):
    """
    Processes a batch.

    Parameters
    ----------
    batch : dict
        Input batch
    return_logs : bool
        True if logs are stored
    progress : float
        Training progress percentage

    Returns
    -------
    output : dict
        Dictionary containing a "loss" scalar and different metrics and predictions
        for logging and downstream usage.
    """
    if not self.training:
        # If not training, no need for self-supervised loss
        return SfmModel.forward(self, batch)
    else:
        # Calculate predicted depth and pose output
        output = super().forward(batch, return_logs=return_logs)

        # Build ground-truth relative poses between target and context cameras
        poses_gt = [torch.zeros((6, 4, 4)), torch.zeros((6, 4, 4))]
        for i in range(6):
            poses_gt[0][i] = batch['pose_context'][0][i].inverse() @ batch['pose'][i]
            poses_gt[1][i] = batch['pose_context'][1][i].inverse() @ batch['pose'][i]
        poses_gt = [Pose(poses_gt[0]), Pose(poses_gt[1])]

        multiview_loss = self.multiview_photometric_loss(
            batch['rgb_original'], batch['rgb_context_original'],
            output['inv_depths'], poses_gt, batch['intrinsics'], batch['extrinsics'],
            return_logs=return_logs, progress=progress)

        # The multiview photometric term is computed but currently not added to the loss
        # loss = multiview_loss['loss']
        loss = 0.

        # Calculate supervised loss
        supervision_loss = self.supervised_loss(
            output['inv_depths'], depth2inv(batch['depth']),
            return_logs=return_logs, progress=progress)
        loss += self.supervised_loss_weight * supervision_loss['loss']

        # Return loss and metrics
        return {
            'loss': loss,
            **merge_outputs(merge_outputs(multiview_loss, supervision_loss), output)
        }
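
# Illustrative sketch: the ground-truth relative poses above are built as
# T_context^{-1} @ T_target, assuming 4x4 homogeneous pose matrices (as suggested
# by the .inverse() / @ composition in the code). Minimal standalone version:
import torch

def relative_pose(T_context, T_target):
    """Relative 4x4 transform between two absolute camera poses."""
    return T_context.inverse() @ T_target

# Example: if the context pose is the identity, the relative pose equals the target pose.
T_identity = torch.eye(4)
T_target_example = torch.eye(4)
T_target_example[:3, 3] = torch.tensor([1.0, 0.0, 0.0])  # 1 m translation along x
assert torch.allclose(relative_pose(T_identity, T_target_example), T_target_example)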
def forward(self, image, context, inv_depths, poses,
            path_to_ego_mask, path_to_ego_mask_context,
            K, ref_K, extrinsics, ref_extrinsics, context_type,
            return_logs=False, progress=0.0):
    """
    Calculates training photometric loss.

    Parameters
    ----------
    image : torch.Tensor [B,3,H,W]
        Original image
    context : list of torch.Tensor [B,3,H,W]
        Context containing a list of reference images
    inv_depths : list of torch.Tensor [B,1,H,W]
        Predicted depth maps for the original image, in all scales
    poses : list of Pose
        Camera transformations between original and context
    path_to_ego_mask : list of str
        Paths to the ego-mask files (.npy) of the original camera
    path_to_ego_mask_context : list
        Paths to the ego-mask files (.npy) of the context cameras
    K : torch.Tensor [B,3,3]
        Original camera intrinsics
    ref_K : torch.Tensor [B,3,3]
        Reference camera intrinsics
    extrinsics : torch.Tensor
        Original camera extrinsics
    ref_extrinsics : torch.Tensor
        Reference camera extrinsics
    context_type : list
        Type of each context frame
    return_logs : bool
        True if logs are saved for visualization
    progress : float
        Training percentage

    Returns
    -------
    losses_and_metrics : dict
        Output dictionary
    """
    # If using progressive scaling
    self.n = self.progressive_scaling(progress)
    # Loop over all reference images
    photometric_losses = [[] for _ in range(self.n)]
    images = match_scales(image, inv_depths, self.n)

    # Debug: replace the predicted depths with a constant depth of 2.0
    # (the commented formula below is an alternative synthetic depth profile)
    inv_depths2 = []
    for k in range(len(inv_depths)):
        Btmp, C, H, W = inv_depths[k].shape
        depths2 = torch.zeros_like(inv_depths[k])
        for i in range(H):
            for j in range(W):
                depths2[0, 0, i, j] = 2.0
                # (2/H)**4 * (2/W)**4 * i**2 * (H - i)**2 * j**2 * (W - j)**2 * 20
        inv_depths2.append(depth2inv(depths2))

    if is_list(context_type[0][0]):
        n_context = len(context_type[0])
    else:
        n_context = len(context_type)
    # n_context = len(context)

    device = image.get_device()
    B = len(path_to_ego_mask)
    if is_list(path_to_ego_mask[0]):
        H_full, W_full = np.load(path_to_ego_mask[0][0]).shape
    else:
        H_full, W_full = np.load(path_to_ego_mask[0]).shape

    # Getting ego masks for target and source cameras
    # Full-size masks
    ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
    ref_ego_mask_tensor = []
    for i_context in range(n_context):
        ref_ego_mask_tensor.append(torch.ones(B, 1, H_full, W_full).to(device))
    for b in range(B):
        if self.mask_ego:
            if is_list(path_to_ego_mask[b]):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b][0])).float()
            else:
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
            for i_context in range(n_context):
                if is_list(path_to_ego_mask_context[0][0]):
                    paths_context_ego = [p[i_context][0] for p in path_to_ego_mask_context]
                else:
                    paths_context_ego = path_to_ego_mask_context[i_context]
                ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                    np.load(paths_context_ego[b])).float()

    # Resized masks (one per scale)
    ego_mask_tensors = []
    ref_ego_mask_tensors = []
    for i_context in range(n_context):
        ref_ego_mask_tensors.append([])
    for i in range(self.n):
        Btmp, C, H, W = images[i].shape
        if W < W_full:
            # inv_scale_factor = int(W_full / W)
            # ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
            ego_mask_tensors.append(
                interpolate_image(ego_mask_tensor, shape=(Btmp, 1, H, W),
                                  mode='nearest', align_corners=None))
            for i_context in range(n_context):
                # ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                ref_ego_mask_tensors[i_context].append(
                    interpolate_image(ref_ego_mask_tensor[i_context], shape=(Btmp, 1, H, W),
                                      mode='nearest', align_corners=None))
        else:
            ego_mask_tensors.append(ego_mask_tensor)
            for i_context in range(n_context):
                ref_ego_mask_tensors[i_context].append(ref_ego_mask_tensor[i_context])

    # Context masks, resized to the context image resolution
    for i_context in range(n_context):
        _, C, H, W = context[i_context].shape
        if W < W_full:
            inv_scale_factor = int(W_full / W)
            # ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
            ref_ego_mask_tensor[i_context] = interpolate_image(
                ref_ego_mask_tensor[i_context], shape=(Btmp, 1, H, W),
                mode='nearest', align_corners=None)

    print(ref_extrinsics)
    for j, (ref_image, pose) in enumerate(zip(context, poses)):
        print(ref_extrinsics[j])
        ref_context_type = [c[j][0] for c in context_type] if is_list(context_type[0][0]) \
            else context_type[j]
        print(ref_context_type)
        print(pose.mat)

        # Calculate warped images
        ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image(
            inv_depths2, ref_image, ref_ego_mask_tensor[j],
            K, ref_K[:, j, :, :], pose, ref_extrinsics[j], ref_context_type)
        print(pose.mat)

        # Calculate and store image loss
        photometric_loss = self.calc_photometric_loss(ref_warped, images)

        # Debug dumps: save original, warped and target images (masked and unmasked)
        tt = str(int(time.time() % 10000))
        for i in range(self.n):
            B, C, H, W = images[i].shape
            for b in range(B):
                # Convert [C,H,W] tensors to [H,W,C] numpy arrays
                orig_PIL_0 = ref_image[b].permute(1, 2, 0).detach().cpu().numpy()
                orig_PIL = (ref_image[b] * ref_ego_mask_tensor[j][b]).permute(1, 2, 0).detach().cpu().numpy()
                warped_PIL_0 = ref_warped[i][b].permute(1, 2, 0).detach().cpu().numpy()
                warped_PIL = (ref_warped[i][b] * ego_mask_tensors[i][b]).permute(1, 2, 0).detach().cpu().numpy()
                target_PIL_0 = images[i][b].permute(1, 2, 0).detach().cpu().numpy()
                target_PIL = (images[i][b] * ego_mask_tensors[i][b]).permute(1, 2, 0).detach().cpu().numpy()

                prefix = '/home/users/vbelissen/test_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i)
                cv2.imwrite(prefix + '_orig_PIL_0.png', orig_PIL_0 * 255)
                cv2.imwrite(prefix + '_orig_PIL.png', orig_PIL * 255)
                cv2.imwrite(prefix + '_warped_PIL_0.png', warped_PIL_0 * 255)
                cv2.imwrite(prefix + '_warped_PIL.png', warped_PIL * 255)
                cv2.imwrite(prefix + '_target_PIL_0.png', target_PIL_0 * 255)
                cv2.imwrite(prefix + '_target_PIL.png', target_PIL * 255)

        for i in range(self.n):
            photometric_losses[i].append(photometric_loss[i]
                                         * ego_mask_tensors[i]
                                         * ref_ego_mask_tensors_warped[i])

        # If using automask
        if self.automask_loss:
            # Calculate and store unwarped image loss
            ref_images = match_scales(ref_image, inv_depths, self.n)
            unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
            for i in range(self.n):
                photometric_losses[i].append(unwarped_image_loss[i]
                                             * ego_mask_tensors[i]
                                             * ref_ego_mask_tensors[j][i])

    # Calculate reduced photometric loss
    loss = self.nonzero_reduce_photometric_loss(photometric_losses)

    # Include smoothness loss if requested
    if self.smooth_loss_weight > 0.0:
        loss += self.calc_smoothness_loss(
            [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
            [a * b for a, b in zip(images, ego_mask_tensors)])

    # Return losses and metrics
    return {
        'loss': loss.unsqueeze(0),
        'metrics': self.metrics,
    }
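
# Illustrative sketch (plain PyTorch, not the repository helper): the loss above
# resizes binary ego masks to every pyramid scale with nearest-neighbour
# interpolation so that mask values stay exactly 0 or 1. `interpolate_image` is the
# repository function; an equivalent using torch.nn.functional could look like this.
import torch
import torch.nn.functional as F

def resize_mask(mask, height, width):
    """Nearest-neighbour resize of a [B,1,H,W] binary mask."""
    return F.interpolate(mask, size=(height, width), mode='nearest')

# Example: a full-resolution mask resized to a half-resolution pyramid level.
full_mask = torch.ones(1, 1, 800, 1280)
half_mask = resize_mask(full_mask, 400, 640)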
                                      mode='bilinear', padding_mode='zeros', align_corners=True)
warped_right_front_PIL = torch.transpose(warped_right_front.unsqueeze(4), 1, 4).squeeze().cpu().numpy()
cv2.imwrite('/home/users/vbelissen/test' + tt + '_right_front.png', warped_right_front_PIL[:, :, ::-1])

simulated_depth_right_to_front = funct.grid_sample(simulated_depth_right, ref_coords_right,
                                                   mode='bilinear', padding_mode='zeros',
                                                   align_corners=True)

viz_pred_inv_depth_front = viz_inv_depth(depth2inv(simulated_depth)[0], normalizer=1.0) * 255
viz_pred_inv_depth_right = viz_inv_depth(depth2inv(simulated_depth_right)[0], normalizer=1.0) * 255
viz_pred_inv_depth_right_to_front = viz_inv_depth(
    depth2inv(simulated_depth_right_to_front)[0], normalizer=1.0) * 255

world_points_right_in_front_coords = cam_front.Tcw @ world_points_right
simulated_depth_right_in_front_coords = torch.norm(world_points_right_in_front_coords,
                                                   dim=1, keepdim=True)
simulated_depth_right_in_front_coords[~not_masked_right] = 0
simulated_depth_right_to_front_in_front_coords = funct.grid_sample(
    simulated_depth_right_in_front_coords, ref_coords_right,
    mode='bilinear', padding_mode='zeros',
def warp_ref(self, inv_depths, ref_image, ref_tensor, ref_inv_depths,
             path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors,
             ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point,
             ref_scale_factors, same_timestamp_as_origin, pose_matrix_context, pose,
             warp_ref_depth, allow_context_rotation):
    """
    Warps a reference image to produce a reconstruction of the original one.

    Parameters
    ----------
    inv_depths : list of torch.Tensor [B,1,H,W]
        Inverse depth maps of the original image, in all scales
    ref_image : torch.Tensor [B,3,H,W]
        Reference RGB image
    ref_tensor : torch.Tensor
        Reference tensor (e.g. ego mask) warped alongside the image
    ref_inv_depths : list of torch.Tensor [B,1,H,W]
        Inverse depth maps predicted for the reference image
    path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
        Fisheye intrinsics of the original camera
    ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
        Fisheye intrinsics of the reference camera
    same_timestamp_as_origin : list of bool
        Whether each batch element is a spatial context (same timestamp as the original)
    pose_matrix_context : torch.Tensor [B,4,4]
        Extrinsic pose matrix of the context camera
    pose : Pose
        Original -> Reference camera transformation
    warp_ref_depth : bool
        Whether to also warp the reference depth maps
    allow_context_rotation : bool
        Whether to compose the context rotation into the warping pose

    Returns
    -------
    ref_warped : list of torch.Tensor [B,3,H,W]
        Warped reference images (reconstructing the original one), in all scales
    ref_tensors_warped : list of torch.Tensor
        Warped reference tensors
    inv_depths_wrt_ref_cam : list of torch.Tensor or None
        Inverse depths of the original image expressed in the reference camera
    ref_inv_depths_warped : list of torch.Tensor or None
        Warped reference inverse depth maps
    """
    B, _, H, W = ref_image.shape
    device = ref_image.get_device()

    # Generate cameras for all scales
    cams, ref_cams = [], []
    if allow_context_rotation:
        pose_matrix = torch.zeros(B, 4, 4)
        for b in range(B):
            if same_timestamp_as_origin[b]:
                pose_matrix[b, :3, 3] = pose.mat[b, :3, :3] @ pose_matrix_context[b, :3, 3]
                pose_matrix[b, 3, 3] = 1
                pose_matrix[b, :3, :3] = pose.mat[b, :3, :3] @ pose_matrix_context[b, :3, :3]
            else:
                pose_matrix[b, :, :] = pose.mat[b, :, :]
    else:
        for b in range(B):
            if same_timestamp_as_origin[b]:
                pose.mat[b, :, :] = pose_matrix_context[b, :, :]

    for i in range(self.n):
        _, _, DH, DW = inv_depths[i].shape
        scale_factor = DW / float(W)
        cams.append(
            CameraFisheye(path_to_theta_lut=path_to_theta_lut,
                          path_to_ego_mask=path_to_ego_mask,
                          poly_coeffs=poly_coeffs.float(),
                          principal_point=principal_point.float(),
                          scale_factors=scale_factors.float()).scaled(scale_factor).to(device))
        ref_cams.append(
            CameraFisheye(path_to_theta_lut=ref_path_to_theta_lut,
                          path_to_ego_mask=ref_path_to_ego_mask,
                          poly_coeffs=ref_poly_coeffs.float(),
                          principal_point=ref_principal_point.float(),
                          scale_factors=ref_scale_factors.float(),
                          Tcw=pose_matrix if allow_context_rotation else pose).scaled(scale_factor).to(device))

    # View synthesis
    depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
    ref_images = match_scales(ref_image, inv_depths, self.n)
    if warp_ref_depth:
        ref_depths = [inv2depth(ref_inv_depths[i]) for i in range(self.n)]
        ref_warped = []
        depths_wrt_ref_cam = []
        ref_depths_warped = []
        for i in range(self.n):
            view_i, depth_wrt_ref_cam_i, ref_depth_warped_i = view_depth_synthesis2(
                ref_images[i], depths[i], ref_depths[i], ref_cams[i], cams[i],
                padding_mode=self.padding_mode)
            ref_warped.append(view_i)
            depths_wrt_ref_cam.append(depth_wrt_ref_cam_i)
            ref_depths_warped.append(ref_depth_warped_i)
        inv_depths_wrt_ref_cam = [depth2inv(depths_wrt_ref_cam[i]) for i in range(self.n)]
        ref_inv_depths_warped = [depth2inv(ref_depths_warped[i]) for i in range(self.n)]
    else:
        ref_warped = [view_synthesis(ref_images[i], depths[i], ref_cams[i], cams[i],
                                     padding_mode=self.padding_mode)
                      for i in range(self.n)]
        inv_depths_wrt_ref_cam = None
        ref_inv_depths_warped = None

    ref_tensors = match_scales(ref_tensor, inv_depths, self.n, mode='nearest', align_corners=None)
    ref_tensors_warped = [view_synthesis(ref_tensors[i], depths[i], ref_cams[i], cams[i],
                                         padding_mode=self.padding_mode,
                                         mode='nearest', align_corners=None)
                          for i in range(self.n)]

    # Return warped reference image, warped tensors and (optionally) warped depths
    return ref_warped, ref_tensors_warped, inv_depths_wrt_ref_cam, ref_inv_depths_warped
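
# Illustrative sketch (pinhole stand-in, not the CameraFisheye model used above):
# warp_ref builds one camera per pyramid level, scaled by scale_factor = DW / W.
# For a plain pinhole intrinsics matrix, the analogous scaling would simply multiply
# the focal lengths and principal point by the same factor (assumption: simple rescale).
import torch

def scale_pinhole_intrinsics(K, scale_factor):
    """Scale a [3,3] pinhole intrinsics matrix to a resized image."""
    K_scaled = K.clone()
    K_scaled[0, 0] *= scale_factor  # fx
    K_scaled[1, 1] *= scale_factor  # fy
    K_scaled[0, 2] *= scale_factor  # cx
    K_scaled[1, 2] *= scale_factor  # cy
    return K_scaled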
def infer_optimal_calib(input_files, model_wrappers, image_shape):
    """
    Process a list of input files to infer a correction to the extrinsic calibration.
    Files should all correspond to the same car.
    The number of cameras is assumed to be 4 or 5.

    Parameters
    ----------
    input_files : list (number of cameras) of lists (number of files) of str
        Image files
    model_wrappers : nn.Module
        Model wrappers used for inference
    image_shape : tuple of int
        Input image shape
    """
    N_files = len(input_files[0])
    N_cams = len(input_files)
    image_area = image_shape[0] * image_shape[1]

    camera_context_pairs = CAMERA_CONTEXT_PAIRS[N_cams]

    # Rotation will be optimized if not all cams are frozen
    optimize_rotation = (args.frozen_cams_rot != [i for i in range(N_cams)])
    # Translation will be optimized if not all cams are frozen
    optimize_translation = (args.frozen_cams_trans != [i for i in range(N_cams)])

    calib_data = {}
    for i_cam in range(N_cams):
        base_folder_str = get_base_folder(input_files[i_cam][0])
        split_type_str = get_split_type(input_files[i_cam][0])
        seq_name_str = get_sequence_name(input_files[i_cam][0])
        camera_str = get_camera_name(input_files[i_cam][0])
        calib_data[camera_str] = read_raw_calib_files_camera_valeo_with_suffix(
            base_folder_str, split_type_str, seq_name_str, camera_str, args.calibrations_suffix)

    cams = []
    cams_untouched = []
    not_masked = []

    # Assume all images are from the same sequence (thus same cameras)
    for i_cam in range(N_cams):
        path_to_ego_mask = get_path_to_ego_mask(input_files[i_cam][0])
        poly_coeffs, principal_point, scale_factors, K, k, p = get_full_intrinsics(
            input_files[i_cam][0], calib_data)

        poly_coeffs_untouched = torch.from_numpy(poly_coeffs).unsqueeze(0)
        principal_point_untouched = torch.from_numpy(principal_point).unsqueeze(0)
        scale_factors_untouched = torch.from_numpy(scale_factors).unsqueeze(0)
        K_untouched = torch.from_numpy(K).unsqueeze(0)
        k_untouched = torch.from_numpy(k).unsqueeze(0)
        p_untouched = torch.from_numpy(p).unsqueeze(0)
        pose_matrix_untouched = torch.from_numpy(
            get_extrinsics_pose_matrix(input_files[i_cam][0], calib_data)).unsqueeze(0)
        pose_tensor_untouched = Pose(pose_matrix_untouched)
        camera_type_untouched = get_camera_type(input_files[i_cam][0], calib_data)
        camera_type_int_untouched = torch.tensor([get_camera_type_int(camera_type_untouched)])

        cams.append(
            CameraMultifocal(poly_coeffs=poly_coeffs_untouched.float(),
                             principal_point=principal_point_untouched.float(),
                             scale_factors=scale_factors_untouched.float(),
                             K=K_untouched.float(),
                             k1=k_untouched[:, 0].float(),
                             k2=k_untouched[:, 1].float(),
                             k3=k_untouched[:, 2].float(),
                             p1=p_untouched[:, 0].float(),
                             p2=p_untouched[:, 1].float(),
                             camera_type=camera_type_int_untouched,
                             Tcw=pose_tensor_untouched))
        cams_untouched.append(
            CameraMultifocal(poly_coeffs=poly_coeffs_untouched.float(),
                             principal_point=principal_point_untouched.float(),
                             scale_factors=scale_factors_untouched.float(),
                             K=K_untouched.float(),
                             k1=k_untouched[:, 0].float(),
                             k2=k_untouched[:, 1].float(),
                             k3=k_untouched[:, 2].float(),
                             p1=p_untouched[:, 0].float(),
                             p2=p_untouched[:, 1].float(),
                             camera_type=camera_type_int_untouched,
                             Tcw=pose_tensor_untouched))

        if torch.cuda.is_available():
            cams[i_cam] = cams[i_cam].to('cuda:{}'.format(rank()))
            cams_untouched[i_cam] = cams_untouched[i_cam].to('cuda:{}'.format(rank()))

        ego_mask = np.load(path_to_ego_mask)
        not_masked.append(torch.from_numpy(ego_mask.astype(float)).cuda().float())

    # Learning variables
    extra_trans_m = [torch.autograd.Variable(torch.zeros(3).cuda(), requires_grad=True)
                     for _ in range(N_cams)]
    extra_rot_deg = [torch.autograd.Variable(torch.zeros(3).cuda(), requires_grad=True)
                     for _ in range(N_cams)]
    # Constraints: translation
    frozen_cams_trans = args.frozen_cams_trans
    if frozen_cams_trans is not None:
        for i_cam in frozen_cams_trans:
            extra_trans_m[i_cam].requires_grad = False
    # Constraints: rotation
    frozen_cams_rot = args.frozen_cams_rot
    if frozen_cams_rot is not None:
        for i_cam in frozen_cams_rot:
            extra_rot_deg[i_cam].requires_grad = False

    # Parameters from argument parser
    save_pictures = args.save_pictures
    n_epochs = args.n_epochs
    learning_rate = args.lr
    step_size = args.scheduler_step_size
    gamma = args.scheduler_gamma

    # Table of loss
    loss_tab = np.zeros(n_epochs)
    # Table of extra rotation values
    extra_rot_values_tab = np.zeros((N_cams * 3, N_files * n_epochs))
    # Table of extra translation values
    extra_trans_values_tab = np.zeros((N_cams * 3, N_files * n_epochs))

    # Optimizer
    optimizer = optim.Adam(extra_trans_m + extra_rot_deg, lr=learning_rate)
    # Scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    # Regularization weights
    regul_weight_trans = torch.tensor(args.regul_weight_trans).cuda()
    regul_weight_rot = torch.tensor(args.regul_weight_rot).cuda()
    regul_weight_overlap = torch.tensor(args.regul_weight_overlap).cuda()

    # Loop on the number of epochs
    count = 0
    for epoch in range(n_epochs):
        print('Epoch ' + str(epoch) + '/' + str(n_epochs))

        # Initialize loss
        loss_sum = 0
        # Loop on the number of files
        for i_file in range(N_files):
            print('')
            # Filename for camera 0
            base_0, ext_0 = os.path.splitext(os.path.basename(input_files[0][i_file]))
            print(base_0)

            # Initialize list of tensors: images, predicted inverse depths and predicted depths
            images, pred_inv_depths, pred_depths = [], [], []
            input_depth_files, has_gt_depth, gt_depth, gt_inv_depth = [], [], [], []
            nb_gt_depths = 0

            # Reset camera poses Twc
            CameraMultifocal.Twc.fget.cache_clear()

            # Loop on cams and predict depth
            for i_cam in range(N_cams):
                images.append(load_image(input_files[i_cam][i_file]).convert('RGB'))
                images[i_cam] = resize_image(images[i_cam], image_shape)
                images[i_cam] = to_tensor(images[i_cam]).unsqueeze(0)
                if torch.cuda.is_available():
                    images[i_cam] = images[i_cam].to('cuda:{}'.format(rank()))
                with torch.no_grad():
                    pred_inv_depths.append(model_wrappers[i_cam].depth(images[i_cam]))
                    pred_depths.append(inv2depth(pred_inv_depths[i_cam]))

                if args.use_lidar:
                    input_depth_files.append(
                        get_depth_file(input_files[i_cam][i_file], args.depth_suffix))
                    has_gt_depth.append(os.path.exists(input_depth_files[i_cam]))
                    if has_gt_depth[i_cam]:
                        nb_gt_depths += 1
                        gt_depth.append(
                            np.load(input_depth_files[i_cam])['velodyne_depth'].astype(np.float32))
                        gt_depth[i_cam] = torch.from_numpy(gt_depth[i_cam]).unsqueeze(0).unsqueeze(0)
                        gt_inv_depth.append(depth2inv(gt_depth[i_cam]))
                        if torch.cuda.is_available():
                            gt_depth[i_cam] = gt_depth[i_cam].to('cuda:{}'.format(rank()))
                            gt_inv_depth[i_cam] = gt_inv_depth[i_cam].to('cuda:{}'.format(rank()))
                    else:
                        gt_depth.append(0)
                        gt_inv_depth.append(0)

                # Apply correction on cams
                pose_matrix = get_extrinsics_pose_matrix_extra_trans_rot_torch(
                    input_files[i_cam][i_file], calib_data,
                    extra_trans_m[i_cam], extra_rot_deg[i_cam]).unsqueeze(0)
                pose_tensor = Pose(pose_matrix).to('cuda:{}'.format(rank()))
                cams[i_cam].Tcw = pose_tensor

            # Define a loss function between 2 images
            def photo_loss_2imgs(i_cam1, i_cam2, save_pictures):
                # Computes the photometric loss between 2 images of adjacent cameras.
                # It reconstructs each image from the adjacent one, applying the correction
                # in rotation and translation.
                # Reconstruct 3D points for each cam
                world_points1 = cams[i_cam1].reconstruct(pred_depths[i_cam1], frame='w')
                world_points2 = cams[i_cam2].reconstruct(pred_depths[i_cam2], frame='w')

                # Get coordinates of projected points on other cam
                ref_coords1to2 = cams[i_cam2].project(world_points1, frame='w')
                ref_coords2to1 = cams[i_cam1].project(world_points2, frame='w')

                # Reconstruct each image from the adjacent camera
                reconstructedImg2to1 = funct.grid_sample(images[i_cam2] * not_masked[i_cam2],
                                                         ref_coords1to2,
                                                         mode='bilinear', padding_mode='zeros',
                                                         align_corners=True)
                reconstructedImg1to2 = funct.grid_sample(images[i_cam1] * not_masked[i_cam1],
                                                         ref_coords2to1,
                                                         mode='bilinear', padding_mode='zeros',
                                                         align_corners=True)

                # Save pictures if requested
                if save_pictures:
                    # Save original files if first epoch
                    if epoch == 0:
                        cv2.imwrite(args.save_folder + '/cam_' + str(i_cam1) + '_file_' + str(i_file) + '_orig.png',
                                    (images[i_cam1][0].permute(1, 2, 0))[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)
                        cv2.imwrite(args.save_folder + '/cam_' + str(i_cam2) + '_file_' + str(i_file) + '_orig.png',
                                    (images[i_cam2][0].permute(1, 2, 0))[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)
                    # Save reconstructed images
                    cv2.imwrite(args.save_folder + '/epoch_' + str(epoch) + '_file_' + str(i_file)
                                + '_cam_' + str(i_cam1) + '_recons_from_' + str(i_cam2) + '.png',
                                ((reconstructedImg2to1 * not_masked[i_cam1])[0].permute(1, 2, 0))[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)
                    cv2.imwrite(args.save_folder + '/epoch_' + str(epoch) + '_file_' + str(i_file)
                                + '_cam_' + str(i_cam2) + '_recons_from_' + str(i_cam1) + '.png',
                                ((reconstructedImg1to2 * not_masked[i_cam2])[0].permute(1, 2, 0))[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)

                # L1 loss
                l1_loss_1 = torch.abs(images[i_cam1] * not_masked[i_cam1]
                                      - reconstructedImg2to1 * not_masked[i_cam1])
                l1_loss_2 = torch.abs(images[i_cam2] * not_masked[i_cam2]
                                      - reconstructedImg1to2 * not_masked[i_cam2])

                # SSIM loss
                ssim_loss_weight = 0.85
                ssim_loss_1 = SSIM(images[i_cam1] * not_masked[i_cam1],
                                   reconstructedImg2to1 * not_masked[i_cam1],
                                   C1=1e-4, C2=9e-4, kernel_size=3)
                ssim_loss_2 = SSIM(images[i_cam2] * not_masked[i_cam2],
                                   reconstructedImg1to2 * not_masked[i_cam2],
                                   C1=1e-4, C2=9e-4, kernel_size=3)
                ssim_loss_1 = torch.clamp((1. - ssim_loss_1) / 2., 0., 1.)
                ssim_loss_2 = torch.clamp((1. - ssim_loss_2) / 2., 0., 1.)

                # Photometric loss: alpha * ssim + (1 - alpha) * l1
                photometric_loss_1 = ssim_loss_weight * ssim_loss_1.mean(1, True) \
                    + (1 - ssim_loss_weight) * l1_loss_1.mean(1, True)
                photometric_loss_2 = ssim_loss_weight * ssim_loss_2.mean(1, True) \
                    + (1 - ssim_loss_weight) * l1_loss_2.mean(1, True)

                # Compute the number of valid pixels
                mask1 = (reconstructedImg2to1 * not_masked[i_cam1]).sum(axis=1, keepdim=True) != 0
                s1 = mask1.sum().float()
                mask2 = (reconstructedImg1to2 * not_masked[i_cam2]).sum(axis=1, keepdim=True) != 0
                s2 = mask2.sum().float()

                # Compute the photometric losses weighed by the number of valid pixels
                loss_1 = (photometric_loss_1 * mask1).sum() / s1 if s1 > 0 else 0
                loss_2 = (photometric_loss_2 * mask2).sum() / s2 if s2 > 0 else 0

                # The final loss can be regularized to encourage a similar overlap between images
                if s1 > 0 and s2 > 0:
                    return loss_1 + loss_2 \
                        + regul_weight_overlap * image_area * (1 / s1 + 1 / s2)
                else:
                    return 0.
            def lidar_loss(i_cam1, save_pictures):
                if args.use_lidar and has_gt_depth[i_cam1]:
                    mask_zeros_lidar = (gt_depth[i_cam1][0, 0, :, :] == 0).detach()

                    # Ground truth sparse depth maps were generated using the untouched camera extrinsics
                    world_points_gt_oldCalib = cams_untouched[i_cam1].reconstruct(gt_depth[i_cam1], frame='w')
                    world_points_gt_oldCalib[0, 0, mask_zeros_lidar] = 0.

                    # Get coordinates of projected points on new cam
                    ref_coords = cams[i_cam1].project(world_points_gt_oldCalib, frame='w')
                    ref_coords[0, mask_zeros_lidar, :] = 0.

                    # Reconstruct projected lidar from the new camera
                    reprojected_gt_inv_depth = funct.grid_sample(gt_inv_depth[i_cam1], ref_coords,
                                                                 mode='nearest', padding_mode='zeros',
                                                                 align_corners=True)
                    reprojected_gt_inv_depth[0, 0, mask_zeros_lidar] = 0.

                    mask_reprojected = (reprojected_gt_inv_depth > 0.).detach()

                    if save_pictures:
                        mask_reprojected_numpy = mask_reprojected[0, 0, :, :].cpu().numpy()
                        u = np.where(mask_reprojected_numpy)[0]
                        v = np.where(mask_reprojected_numpy)[1]
                        n_lidar = u.size
                        reprojected_gt_depth_numpy = \
                            inv2depth(reprojected_gt_inv_depth)[0, 0, :, :].detach().cpu().numpy()

                        im = (images[i_cam1][0].permute(1, 2, 0))[:, :, [2, 1, 0]].detach().cpu().numpy() * 255
                        dmax = 100.
                        # Color-code each reprojected lidar point according to its depth
                        for i_l in range(n_lidar):
                            d = reprojected_gt_depth_numpy[u[i_l], v[i_l]]
                            s = int((8 / d)) + 1
                            im[u[i_l] - s:u[i_l] + s, v[i_l] - s:v[i_l] + s, 0] = \
                                np.clip(np.power(d / dmax, .7) * 255, 10, 245)
                            im[u[i_l] - s:u[i_l] + s, v[i_l] - s:v[i_l] + s, 1] = \
                                np.clip(np.power((dmax - d) / dmax, 4.0) * 255, 10, 245)
                            im[u[i_l] - s:u[i_l] + s, v[i_l] - s:v[i_l] + s, 2] = \
                                np.clip(np.power(np.abs(2 * (d - .5 * dmax) / dmax), 3.0) * 255, 10, 245)
                        cv2.imwrite(args.save_folder + '/epoch_' + str(epoch) + '_file_' + str(i_file)
                                    + '_cam_' + str(i_cam1) + '_lidar.png', im)

                    if mask_reprojected.sum() > 0:
                        return l1_lidar_loss(pred_inv_depths[i_cam1] * not_masked[i_cam1],
                                             reprojected_gt_inv_depth * not_masked[i_cam1])
                    else:
                        return 0.
                else:
                    return 0.

            if nb_gt_depths > 0:
                final_lidar_weight = (N_cams / nb_gt_depths) * args.lidar_weight
            else:
                final_lidar_weight = 0.

            # The final loss consists of summing the photometric loss of all pairs of adjacent cameras
            # and is regularized to prevent weights from exploding
            photo_loss = 1.0 * sum([photo_loss_2imgs(p[0], p[1], save_pictures)
                                    for p in camera_context_pairs])
            regul_rot_loss = regul_weight_rot * sum([(extra_rot_deg[i] ** 2).sum()
                                                     for i in range(N_cams)])
            regul_trans_loss = regul_weight_trans * sum([(extra_trans_m[i] ** 2).sum()
                                                         for i in range(N_cams)])
            lidar_gt_loss = final_lidar_weight * sum([lidar_loss(i, save_pictures)
                                                      for i in range(N_cams)])
            loss = photo_loss + regul_rot_loss + regul_trans_loss + lidar_gt_loss

            with torch.no_grad():
                extra_rot_deg_before = []
                extra_trans_m_before = []
                for i_cam in range(N_cams):
                    extra_rot_deg_before.append(extra_rot_deg[i_cam].clone())
                    extra_trans_m_before.append(extra_trans_m[i_cam].clone())

            # Optimization steps
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                extra_rot_deg_after = []
                extra_trans_m_after = []
                for i_cam in range(N_cams):
                    extra_rot_deg_after.append(extra_rot_deg[i_cam].clone())
                    extra_trans_m_after.append(extra_trans_m[i_cam].clone())

                rot_change_file = 0.
                trans_change_file = 0.
                for i_cam in range(N_cams):
                    rot_change_file += torch.abs(extra_rot_deg_after[i_cam]
                                                 - extra_rot_deg_before[i_cam]).mean().item()
                    trans_change_file += torch.abs(extra_trans_m_after[i_cam]
                                                   - extra_trans_m_before[i_cam]).mean().item()
                rot_change_file /= N_cams
                trans_change_file /= N_cams

                print('Average rotation change (deg.): ' + "{:.4f}".format(rot_change_file))
                print('Average translation change (m.): ' + "{:.4f}".format(trans_change_file))

            # Save correction values and print loss
            with torch.no_grad():
                loss_sum += loss.item()
                for i_cam in range(N_cams):
                    for j in range(3):
                        if optimize_rotation:
                            extra_rot_values_tab[3 * i_cam + j, count] = extra_rot_deg[i_cam][j].item()
                        if optimize_translation:
                            extra_trans_values_tab[3 * i_cam + j, count] = extra_trans_m[i_cam][j].item()
                print('Loss: ' + "{:.3f}".format(loss.item())
                      + ' (photometric: ' + "{:.3f}".format(photo_loss.item())
                      + ', rotation reg.: ' + "{:.4f}".format(regul_rot_loss.item())
                      + ', translation reg.: ' + "{:.4f}".format(regul_trans_loss.item())
                      + ', lidar: ' + "{:.3f}".format(lidar_gt_loss) + ')')
                if nb_gt_depths > 0:
                    print('Number of ground truth lidar maps: ' + str(nb_gt_depths))

            count += 1

        # Update scheduler
        print('Epoch:', epoch, 'LR:', scheduler.get_lr())
        scheduler.step()

        with torch.no_grad():
            print('End of epoch')
            if optimize_translation:
                print('New translation correction values: ')
                print(extra_trans_m)
            if optimize_rotation:
                print('New rotation correction values: ')
                print(extra_rot_deg)
            print('Average rotation change in epoch:')

        loss_tab[epoch] = loss_sum / N_files

    # Plot/save loss if requested
    plt.figure()
    plt.plot(loss_tab)
    if args.show_plots:
        plt.show()
    if args.save_plots:
        plt.savefig(os.path.join(args.save_folder,
                                 get_sequence_name(input_files[0][0]) + '_loss.png'))

    # Plot/save correction values if requested
    if optimize_rotation:
        plt.figure()
        for j in range(N_cams * 3):
            plt.plot(extra_rot_values_tab[j])
        if args.show_plots:
            plt.show()
        if args.save_plots:
            plt.savefig(os.path.join(args.save_folder,
                                     get_sequence_name(input_files[0][0]) + '_extra_rot.png'))
    if optimize_translation:
        plt.figure()
        for j in range(N_cams * 3):
            plt.plot(extra_trans_values_tab[j])
        if args.show_plots:
            plt.show()
        if args.save_plots:
            plt.savefig(os.path.join(args.save_folder,
                                     get_sequence_name(input_files[0][0]) + '_extra_trans.png'))

    # Save correction values table if requested
    if args.save_rot_tab:
        np.save(os.path.join(args.save_folder,
                             get_sequence_name(input_files[0][0]) + '_rot_tab.npy'),
                extra_rot_values_tab)
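
# Illustrative sketch: a standalone version of the photometric term used in
# photo_loss_2imgs above (alpha * SSIM distance + (1 - alpha) * L1, with
# alpha = 0.85, C1 = 1e-4, C2 = 9e-4 and a 3x3 kernel). The SSIM below is a
# simplified average-pooling implementation, not the repository's SSIM helper.
import torch
import torch.nn.functional as F

def ssim_map(x, y, C1=1e-4, C2=9e-4, kernel_size=3):
    """Simplified SSIM map between two [B,3,H,W] images using average pooling."""
    pad = kernel_size // 2

    def pool(z):
        return F.avg_pool2d(F.pad(z, (pad, pad, pad, pad), mode='reflect'), kernel_size, stride=1)

    mu_x, mu_y = pool(x), pool(y)
    sigma_x = pool(x ** 2) - mu_x ** 2
    sigma_y = pool(y ** 2) - mu_y ** 2
    sigma_xy = pool(x * y) - mu_x * mu_y
    numerator = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    denominator = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    return numerator / denominator

def photometric_loss(original, reconstructed, alpha=0.85):
    """Per-pixel photometric loss: alpha * SSIM distance + (1 - alpha) * L1."""
    l1 = torch.abs(original - reconstructed).mean(1, True)
    ssim_dist = torch.clamp((1. - ssim_map(original, reconstructed)) / 2., 0., 1.).mean(1, True)
    return alpha * ssim_dist + (1 - alpha) * l1

# Example with random placeholder images.
a, b = torch.rand(1, 3, 64, 64), torch.rand(1, 3, 64, 64)
loss_map = photometric_loss(a, b)  # [1,1,64,64]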