Example #1
    def calc_sparse_depth_loss(self,
                               sparse_depth,
                               p_world_hat,
                               mask_pred,
                               reduction_method,
                               loss={},
                               eval_mode=False):
        ''' Calculates the sparse depth loss.

        Args:
            sparse_depth (dict): dictionary for sparse depth loss calculation
            p_world_hat (tensor): predicted world points
            mask_pred (tensor): mask for predicted values
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_sparse_depth != 0:
            p_world = sparse_depth['p_world']
            depth_gt = sparse_depth['depth_gt']
            camera_mat = sparse_depth['camera_mat']
            world_mat = sparse_depth['world_mat']
            scale_mat = sparse_depth['scale_mat']

            # Shortcuts
            batch_size, n_points, _ = p_world.shape
            if self.depth_loss_on_world_points:
                loss_sparse_depth = losses.l2_loss(
                    p_world_hat[mask_pred], p_world[mask_pred],
                    reduction_method) * self.lambda_sparse_depth / batch_size
            else:
                d_pred_cam = transform_to_camera_space(p_world_hat, camera_mat,
                                                       world_mat,
                                                       scale_mat)[:, :, -1]
                loss_sparse_depth = losses.l1_loss(
                    d_pred_cam[mask_pred], depth_gt[mask_pred],
                    reduction_method, feat_dim=False) * \
                    self.lambda_sparse_depth / batch_size

            if eval_mode:
                if self.depth_loss_on_world_points:
                    loss_sparse_depth_val = losses.l2_loss(
                        p_world_hat[mask_pred], p_world[mask_pred], 'mean') * \
                        self.lambda_sparse_depth
                else:
                    d_pred_cam = transform_to_camera_space(
                        p_world_hat, camera_mat, world_mat,
                        scale_mat)[:, :, -1]
                    loss_sparse_depth_val = losses.l1_loss(
                        d_pred_cam[mask_pred],
                        depth_gt[mask_pred],
                        'mean',
                        feat_dim=False) * self.lambda_sparse_depth
                loss['loss_sparse_depth_val'] = loss_sparse_depth_val

            loss['loss'] += loss_sparse_depth
            loss['loss_sparse_depth'] = loss_sparse_depth
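
The sparse depth term above reduces to a masked distance between predicted and ground-truth depth (or world points), scaled by lambda_sparse_depth / batch_size. Below is a minimal standalone sketch of the L1-on-camera-depth branch using plain torch; the shapes and the inlined L1 reduction are illustrative assumptions, not the repo's losses.l1_loss or transform_to_camera_space helpers.

import torch

# Illustrative shapes only: a batch of rays with one depth value each.
batch_size, n_points = 2, 100
d_pred_cam = torch.rand(batch_size, n_points)        # predicted depth in camera space
depth_gt = torch.rand(batch_size, n_points)          # ground-truth sparse depth
mask_pred = torch.rand(batch_size, n_points) > 0.5   # valid predictions
lambda_sparse_depth = 1.0

# Masked L1 with 'sum' reduction, normalized by batch size, mirroring the
# reduction used on the training path above.
loss_sparse_depth = (d_pred_cam[mask_pred] - depth_gt[mask_pred]).abs().sum() \
    * lambda_sparse_depth / batch_size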
Example #2

    def calc_depth_loss(self, mask_depth, depth_img, pixels,
                        camera_mat, world_mat, scale_mat, p_world_hat,
                        reduction_method, loss={}, eval_mode=False):
        ''' Calculates the depth loss.

        Args:
            mask_depth (tensor): mask for depth loss
            depth_img (tensor): depth image
            pixels (tensor): sampled pixels in range [-1, 1]
            camera_mat (tensor): camera matrix
            world_mat (tensor): world matrix
            scale_mat (tensor): scale matrix
            p_world_hat (tensor): predicted world points
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_depth != 0 and mask_depth.sum() > 0:
            batch_size, n_pts, _ = p_world_hat.shape
            loss_depth_val = torch.tensor(10)  # placeholder; overwritten below when eval_mode is True
            # For depth values, we have to check again if all values are valid
            # as we potentially train with sparse depth maps
            depth_gt, mask_gt_depth = get_tensor_values(
                depth_img, pixels, squeeze_channel_dim=True, with_mask=True)
            mask_depth &= mask_gt_depth
            if self.depth_loss_on_world_points:
                # Applying L2 loss on world points results in the same as
                # applying L1 on the depth values with scaling (see Sup. Mat.)
                p_world = transform_to_world(
                    pixels, depth_gt.unsqueeze(-1), camera_mat, world_mat,
                    scale_mat)
                loss_depth = losses.l2_loss(
                    p_world_hat[mask_depth], p_world[mask_depth],
                    reduction_method) * self.lambda_depth / batch_size
                if eval_mode:
                    loss_depth_val = losses.l2_loss(
                        p_world_hat[mask_depth], p_world[mask_depth],
                        'mean') * self.lambda_depth
            else:
                d_pred = transform_to_camera_space(
                    p_world_hat, camera_mat, world_mat, scale_mat)[:, :, -1]
                loss_depth = losses.l1_loss(
                    d_pred[mask_depth], depth_gt[mask_depth],
                    reduction_method, feat_dim=False) * \
                    self.lambda_depth / batch_size
                if eval_mode:
                    loss_depth_val = losses.l1_loss(
                        d_pred[mask_depth], depth_gt[mask_depth],
                        'mean', feat_dim=False) * self.lambda_depth

            loss['loss'] += loss_depth
            loss['loss_depth'] = loss_depth
            if eval_mode:
                loss['loss_depth_eval'] = loss_depth_val
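
When depth_loss_on_world_points is set, the comment above notes that an L2 loss on world points is equivalent, up to scaling, to an L1 loss on depth values. A minimal standalone sketch of that masked world-point term follows, assuming losses.l2_loss is a summed squared error (the repo's exact definition may differ) and using made-up tensor shapes.

import torch

# Illustrative shapes only: B rays with predicted and GT 3D world points.
batch_size, n_pts = 2, 128
p_world_hat = torch.rand(batch_size, n_pts, 3)        # predicted world points
p_world = torch.rand(batch_size, n_pts, 3)            # GT points from pixels + GT depth
mask_depth = torch.rand(batch_size, n_pts) > 0.3      # rays to supervise with depth
mask_gt_depth = torch.rand(batch_size, n_pts) > 0.3   # pixels with valid GT depth
lambda_depth = 1.0

# Supervise only where both the prediction mask and the GT depth mask are valid,
# as in the `mask_depth &= mask_gt_depth` step above.
mask = mask_depth & mask_gt_depth

# Summed squared distance between 3D points, normalized per batch.
loss_depth = ((p_world_hat[mask] - p_world[mask]) ** 2).sum() \
    * lambda_depth / batch_size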
Example #3
    def calc_photoconsistency_loss(self,
                                   mask_rgb,
                                   rgb_pred,
                                   img,
                                   pixels,
                                   reduction_method,
                                   loss,
                                   patch_size,
                                   eval_mode=False):
        ''' Calculates the photo-consistency loss.

        Args:
            mask_rgb (tensor): mask for photo-consistency loss
            rgb_pred (tensor): predicted rgb color values
            img (tensor): GT image
            pixels (tensor): sampled pixels in range [-1, 1]
            reduction_method (string): how to reduce the loss tensor
            loss (dict): loss dictionary
            patch_size (int): size of sampled patch
            eval_mode (bool): whether to use eval mode
        '''
        if self.lambda_rgb != 0 and mask_rgb.sum() > 0:
            batch_size, n_pts, _ = rgb_pred.shape
            loss_rgb_eval = torch.tensor(3)  # placeholder; overwritten below when eval_mode is True
            # Get GT RGB values
            rgb_gt = get_tensor_values(img, pixels)

            # 3.1) Calculate RGB Loss
            loss_rgb = losses.l1_loss(
                rgb_pred[mask_rgb], rgb_gt[mask_rgb],
                reduction_method) * self.lambda_rgb / batch_size
            loss['loss'] += loss_rgb
            loss['loss_rgb'] = loss_rgb
            if eval_mode:
                loss_rgb_eval = losses.l1_loss(
                    rgb_pred[mask_rgb], rgb_gt[mask_rgb], 'mean') * \
                    self.lambda_rgb

            # 3.2) Image Gradient loss
            if self.lambda_image_gradients != 0:
                assert (patch_size > 1)
                loss_grad = losses.image_gradient_loss(
                    rgb_pred, rgb_gt, mask_rgb, patch_size,
                    reduction_method) * \
                    self.lambda_image_gradients / batch_size
                loss['loss'] += loss_grad
                loss['loss_image_gradient'] = loss_grad
            if eval_mode:
                loss['loss_rgb_eval'] = loss_rgb_eval
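
For reference, the photo-consistency term is a masked L1 between predicted and ground-truth colors at the sampled pixels, scaled by lambda_rgb / batch_size. A minimal standalone sketch with plain torch is given below; the shapes are assumptions, and neither the repo's losses.l1_loss nor the image-gradient term is reproduced here.

import torch

# Illustrative shapes only: B images, n_pts sampled pixels, RGB channels.
batch_size, n_pts = 2, 256
rgb_pred = torch.rand(batch_size, n_pts, 3)          # predicted colors
rgb_gt = torch.rand(batch_size, n_pts, 3)            # colors sampled from the GT image
mask_rgb = torch.rand(batch_size, n_pts) > 0.2       # rays that hit the surface
lambda_rgb = 1.0

# Masked L1 photo-consistency with 'sum' reduction and per-batch normalization,
# analogous to the loss_rgb computation above.
loss_rgb = (rgb_pred[mask_rgb] - rgb_gt[mask_rgb]).abs().sum() \
    * lambda_rgb / batch_size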