def calc_sparse_depth_loss(self, sparse_depth, p_world_hat, mask_pred,
                           reduction_method, loss={}, eval_mode=False):
    ''' Calculates the sparse depth loss.

    Args:
        sparse_depth (dict): dictionary for sparse depth loss calculation
        p_world_hat (tensor): predicted world points
        mask_pred (tensor): mask for predicted values
        reduction_method (string): how to reduce the loss tensor
        loss (dict): loss dictionary
        eval_mode (bool): whether to use eval mode
    '''
    if self.lambda_sparse_depth != 0:
        p_world = sparse_depth['p_world']
        depth_gt = sparse_depth['depth_gt']
        camera_mat = sparse_depth['camera_mat']
        world_mat = sparse_depth['world_mat']
        scale_mat = sparse_depth['scale_mat']

        # Shortcuts
        batch_size, n_points, _ = p_world.shape

        if self.depth_loss_on_world_points:
            loss_sparse_depth = losses.l2_loss(
                p_world_hat[mask_pred], p_world[mask_pred],
                reduction_method) * self.lambda_sparse_depth / batch_size
        else:
            d_pred_cam = transform_to_camera_space(
                p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
            loss_sparse_depth = losses.l1_loss(
                d_pred_cam[mask_pred], depth_gt[mask_pred],
                reduction_method, feat_dim=False) * \
                self.lambda_sparse_depth / batch_size

        if eval_mode:
            if self.depth_loss_on_world_points:
                loss_sparse_depth_val = losses.l2_loss(
                    p_world_hat[mask_pred], p_world[mask_pred], 'mean') * \
                    self.lambda_sparse_depth
            else:
                d_pred_cam = transform_to_camera_space(
                    p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
                loss_sparse_depth_val = losses.l1_loss(
                    d_pred_cam[mask_pred], depth_gt[mask_pred], 'mean',
                    feat_dim=False) * self.lambda_sparse_depth
            loss['loss_sparse_depth_val'] = loss_sparse_depth_val

        loss['loss'] += loss_sparse_depth
        loss['loss_sparse_depth'] = loss_sparse_depth
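
# Hypothetical usage sketch (added, not part of the original code): the loss
# dictionary is accumulated in place, so the caller is assumed to initialize
# loss['loss'] before this method runs; `trainer` stands for an instance of
# this class with lambda_sparse_depth set to a non-zero weight.
#
#   loss = {'loss': 0.}
#   trainer.calc_sparse_depth_loss(sparse_depth, p_world_hat, mask_pred,
#                                  'sum', loss=loss)
#   total = loss['loss']  # weighted sparse depth term accumulated here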
def calc_depth_loss(self, mask_depth, depth_img, pixels, camera_mat,
                    world_mat, scale_mat, p_world_hat, reduction_method,
                    loss={}, eval_mode=False):
    ''' Calculates the depth loss.

    Args:
        mask_depth (tensor): mask for depth loss
        depth_img (tensor): depth image
        pixels (tensor): sampled pixels in range [-1, 1]
        camera_mat (tensor): camera matrix
        world_mat (tensor): world matrix
        scale_mat (tensor): scale matrix
        p_world_hat (tensor): predicted world points
        reduction_method (string): how to reduce the loss tensor
        loss (dict): loss dictionary
        eval_mode (bool): whether to use eval mode
    '''
    if self.lambda_depth != 0 and mask_depth.sum() > 0:
        batch_size, n_pts, _ = p_world_hat.shape
        loss_depth_val = torch.tensor(10)
        # For depth values, we have to check again if all values are valid
        # as we potentially train with sparse depth maps
        depth_gt, mask_gt_depth = get_tensor_values(
            depth_img, pixels, squeeze_channel_dim=True, with_mask=True)
        mask_depth &= mask_gt_depth

        if self.depth_loss_on_world_points:
            # Applying L2 loss on world points results in the same as
            # applying L1 on the depth values with scaling (see Sup. Mat.)
            p_world = transform_to_world(
                pixels, depth_gt.unsqueeze(-1), camera_mat, world_mat,
                scale_mat)
            loss_depth = losses.l2_loss(
                p_world_hat[mask_depth], p_world[mask_depth],
                reduction_method) * self.lambda_depth / batch_size
            if eval_mode:
                loss_depth_val = losses.l2_loss(
                    p_world_hat[mask_depth], p_world[mask_depth],
                    'mean') * self.lambda_depth
        else:
            d_pred = transform_to_camera_space(
                p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
            loss_depth = losses.l1_loss(
                d_pred[mask_depth], depth_gt[mask_depth],
                reduction_method, feat_dim=False) * \
                self.lambda_depth / batch_size
            if eval_mode:
                loss_depth_val = losses.l1_loss(
                    d_pred[mask_depth], depth_gt[mask_depth], 'mean',
                    feat_dim=False) * self.lambda_depth

        loss['loss'] += loss_depth
        loss['loss_depth'] = loss_depth
        if eval_mode:
            loss['loss_depth_eval'] = loss_depth_val
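
# Note on the world-point branch (added, hedged): the predicted and GT world
# points are un-projected along the same pixel ray, so their difference is
# (d_pred - d_gt) times the ray direction, and the L2 distance between world
# points equals |d_pred - d_gt| scaled by the ray length; this is the
# equivalence the "see Sup. Mat." comment above refers to.
#
# Hypothetical call (assumes depth_img is a (B, 1, H, W) depth map and pixels
# are the sampled coordinates in [-1, 1] used by get_tensor_values):
#
#   loss = {'loss': 0.}
#   trainer.calc_depth_loss(mask_depth, depth_img, pixels, camera_mat,
#                           world_mat, scale_mat, p_world_hat, 'sum',
#                           loss=loss, eval_mode=True)
#   print(loss['loss_depth'], loss['loss_depth_eval'])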
def calc_photoconsistency_loss(self, mask_rgb, rgb_pred, img, pixels,
                               reduction_method, loss, patch_size,
                               eval_mode=False):
    ''' Calculates the photo-consistency loss.

    Args:
        mask_rgb (tensor): mask for photo-consistency loss
        rgb_pred (tensor): predicted rgb color values
        img (tensor): GT image
        pixels (tensor): sampled pixels in range [-1, 1]
        reduction_method (string): how to reduce the loss tensor
        loss (dict): loss dictionary
        patch_size (int): size of sampled patch
        eval_mode (bool): whether to use eval mode
    '''
    if self.lambda_rgb != 0 and mask_rgb.sum() > 0:
        batch_size, n_pts, _ = rgb_pred.shape
        loss_rgb_eval = torch.tensor(3)
        # Get GT RGB values
        rgb_gt = get_tensor_values(img, pixels)

        # 3.1) Calculate RGB Loss
        loss_rgb = losses.l1_loss(
            rgb_pred[mask_rgb], rgb_gt[mask_rgb],
            reduction_method) * self.lambda_rgb / batch_size
        loss['loss'] += loss_rgb
        loss['loss_rgb'] = loss_rgb
        if eval_mode:
            loss_rgb_eval = losses.l1_loss(
                rgb_pred[mask_rgb], rgb_gt[mask_rgb], 'mean') * \
                self.lambda_rgb

        # 3.2) Image Gradient loss
        if self.lambda_image_gradients != 0:
            assert (patch_size > 1)
            loss_grad = losses.image_gradient_loss(
                rgb_pred, rgb_gt, mask_rgb, patch_size,
                reduction_method) * \
                self.lambda_image_gradients / batch_size
            loss['loss'] += loss_grad
            loss['loss_image_gradient'] = loss_grad

        if eval_mode:
            loss['loss_rgb_eval'] = loss_rgb_eval
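
# Hypothetical usage sketch (added, not part of the original code): rgb_pred
# and the GT values sampled from img are assumed to have shape (B, n_pts, 3);
# when lambda_image_gradients is non-zero, the pixels must have been sampled
# as patches with patch_size > 1, otherwise the assertion above fires.
#
#   loss = {'loss': 0.}
#   trainer.calc_photoconsistency_loss(mask_rgb, rgb_pred, img, pixels,
#                                      'sum', loss, patch_size)
#   total = loss['loss']  # RGB term (and gradient term, if enabled) added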