Example no. 1
    def _postprocess(self,
                     net_output,
                     original_size,
                     mode='bilinear',
                     align_corners=False):
        pred_inv_depth_resized = interpolate_image(
            net_output, (original_size[0], original_size[1]), mode,
            align_corners)
        depth_img = self.write_depth(self.inv2depth(pred_inv_depth_resized))

        return depth_img
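# interpolate_image, write_depth and inv2depth above come from the surrounding
# repo. As a rough, hypothetical sketch of the assumed behaviour of
# interpolate_image (a thin wrapper around torch.nn.functional.interpolate;
# the real implementation may differ):
import torch
import torch.nn.functional as F


def interpolate_image_sketch(image, shape, mode='bilinear', align_corners=False):
    """Hypothetical helper: resize a [B,C,H,W] tensor to the (H, W) in `shape`."""
    if len(shape) > 2:
        shape = shape[-2:]  # accept full tensor shapes as well as (H, W)
    # note: align_corners must be None when mode='nearest'
    return F.interpolate(image, size=shape, mode=mode, align_corners=align_corners)

# e.g. upsampling a network output back to the input resolution:
# pred_inv_depth_resized = interpolate_image_sketch(net_output, original_size)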
Example no. 2
def scale_depth(pred, gt, scale_fn):
    """
    Match depth maps to ground-truth resolution

    Parameters
    ----------
    pred : torch.Tensor
        Predicted depth maps [B,1,h,w]
    gt : torch.Tensor
        Ground-truth depth maps [B,1,H,W]
    scale_fn : str
        How to scale output to GT resolution
            'resize': bilinear interpolation to the GT resolution
            'top-center': pad the top and the left/right borders with zeros

    Returns
    -------
    pred : torch.Tensor
        Uncropped predicted depth maps [B,1,H,W]
    """
    if scale_fn == 'resize':
        # Resize depth map to GT resolution
        return interpolate_image(pred,
                                 gt.shape,
                                 mode='bilinear',
                                 align_corners=True)
    else:
        # Create empty depth map with GT resolution
        pred_uncropped = torch.zeros(gt.shape,
                                     dtype=pred.dtype,
                                     device=pred.device)
        # Uncrop top vertically and center horizontally
        if scale_fn == 'top-center':
            top = gt.shape[2] - pred.shape[2]
            left = (gt.shape[3] - pred.shape[3]) // 2
            pred_uncropped[:, :, top:(top + pred.shape[2]),
                           left:(left + pred.shape[3])] = pred
        else:
            raise NotImplementedError(
                'Depth scale function {} not implemented.'.format(scale_fn))
        # Return uncropped depth map
        return pred_uncropped
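# A quick usage sketch for the 'top-center' branch above, with dummy shapes
# (assuming scale_depth as defined above; the 'resize' branch additionally
# needs the repo's interpolate_image helper):
import torch

pred = torch.rand(2, 1, 192, 640)   # cropped prediction
gt = torch.zeros(2, 1, 256, 800)    # ground-truth resolution
uncropped = scale_depth(pred, gt, 'top-center')
assert uncropped.shape == gt.shape
# zeros pad the top rows and the left/right borders; the prediction sits at
# the bottom, horizontally centered
assert uncropped[0, 0, :64].abs().sum() == 0   # top = 256 - 192 = 64 rows of padding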
    def forward(self,
                image,
                context,
                inv_depths,
                poses,
                path_to_ego_mask,
                path_to_ego_mask_context,
                K,
                ref_K,
                extrinsics,
                ref_extrinsics,
                context_type,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        poses : list of Pose
            Camera transformations between original and context images
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        path_to_ego_mask_context : list
            Paths to the ego-mask .npy files of the context cameras
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,n_context,3,3]
            Reference camera intrinsics
        extrinsics : torch.Tensor
            Original camera extrinsics
        ref_extrinsics : torch.Tensor
            Reference camera extrinsics
        context_type : list
            Type of each context frame
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        inv_depths2 = []
        for k in range(len(inv_depths)):

            Btmp, C, H, W = inv_depths[k].shape
            depths2 = torch.zeros_like(inv_depths[k])
            # Debug override: constant depth of 2.0 for the first batch element
            # (a commented-out per-pixel alternative was
            #  (2/H)**4 * (2/W)**4 * i**2 * (H - i)**2 * j**2 * (W - j)**2 * 20)
            depths2[0, 0, :, :] = 2.0

            inv_depths2.append(depth2inv(depths2))

        if is_list(context_type[0][0]):
            n_context = len(context_type[0])
        else:
            n_context = len(context_type)

        #n_context = len(context)
        device = image.get_device()
        B = len(path_to_ego_mask)
        if is_list(path_to_ego_mask[0]):
            H_full, W_full = np.load(path_to_ego_mask[0][0]).shape
        else:
            H_full, W_full = np.load(path_to_ego_mask[0]).shape

        # getting ego masks for target and source cameras
        # fullsize mask
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(
                torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            if self.mask_ego:
                if is_list(path_to_ego_mask[b]):
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b][0])).float()
                else:
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    if is_list(path_to_ego_mask_context[0][0]):
                        paths_context_ego = [
                            p[i_context][0] for p in path_to_ego_mask_context
                        ]
                    else:
                        paths_context_ego = path_to_ego_mask_context[i_context]
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(paths_context_ego[b])).float()
        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            Btmp, C, H, W = images[i].shape
            if W < W_full:
                #inv_scale_factor = int(W_full / W)
                #print(W_full / W)
                #ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor,
                                      shape=(Btmp, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
                for i_context in range(n_context):
                    #ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                    ref_ego_mask_tensors[i_context].append(
                        interpolate_image(ref_ego_mask_tensor[i_context],
                                          shape=(Btmp, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        ref_ego_mask_tensor[i_context])
        for i_context in range(n_context):
            _, C, H, W = context[i_context].shape
            if W < W_full:
                inv_scale_factor = int(W_full / W)
                #ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
                ref_ego_mask_tensor[i_context] = interpolate_image(
                    ref_ego_mask_tensor[i_context],
                    shape=(Btmp, 1, H, W),
                    mode='nearest',
                    align_corners=None)

        print(ref_extrinsics)
        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            print(ref_extrinsics[j])
            ref_context_type = [c[j][0] for c in context_type] if is_list(
                context_type[0][0]) else context_type[j]
            print(ref_context_type)
            print(pose.mat)
            # Calculate warped images
            ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image(
                inv_depths2, ref_image, ref_ego_mask_tensor[j], K,
                ref_K[:, j, :, :], pose, ref_extrinsics[j], ref_context_type)
            print(pose.mat)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)

            tt = str(int(time.time() % 10000))

            def to_numpy(x):
                # [3,H,W] tensor -> [H,W,3] numpy array for cv2.imwrite
                return torch.transpose(x.unsqueeze(0).unsqueeze(4), 1,
                                       4).squeeze().detach().cpu().numpy()

            # Debug: dump original, warped and target images to disk,
            # with and without ego masks
            for i in range(self.n):
                B, C, H, W = images[i].shape
                for b in range(B):
                    debug_images = {
                        'orig_PIL_0': to_numpy(ref_image[b]),
                        'orig_PIL': to_numpy(ref_image[b] *
                                             ref_ego_mask_tensor[j][b]),
                        'warped_PIL_0': to_numpy(ref_warped[i][b]),
                        'warped_PIL': to_numpy(ref_warped[i][b] *
                                               ego_mask_tensors[i][b]),
                        'target_PIL_0': to_numpy(images[i][b]),
                        'target_PIL': to_numpy(images[i][b] *
                                               ego_mask_tensors[i][b]),
                    }
                    for name, img in debug_images.items():
                        cv2.imwrite(
                            '/home/users/vbelissen/test_{}_{}_{}_{}_{}.png'.format(
                                j, tt, b, i, name), img * 255)

            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors_warped[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ref_ego_mask_tensors[j][i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
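# The ego-mask handling above follows a recurring pattern: load a full-resolution
# binary mask from a .npy file, then downscale it with nearest-neighbour
# interpolation for each prediction scale. A minimal standalone sketch of that
# pattern (helper name, file path and shapes are placeholders, not repo code):
import numpy as np
import torch
import torch.nn.functional as F


def load_and_resize_ego_mask(npy_path, target_hw, device='cpu'):
    """Hypothetical helper: .npy ego mask -> [1,1,h,w] tensor at target size."""
    mask = torch.from_numpy(np.load(npy_path)).float()   # [H_full, W_full]
    mask = mask.unsqueeze(0).unsqueeze(0).to(device)     # [1,1,H_full,W_full]
    return F.interpolate(mask, size=target_hw, mode='nearest')

# e.g. load_and_resize_ego_mask('cam_front_ego_mask.npy', (96, 320))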
    def forward(self,
                inv_depths,
                gt_inv_depth,
                path_to_ego_mask,
                return_logs=False,
                progress=0.0):
        """
        Calculates training supervised loss.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        gt_inv_depth : torch.Tensor [B,1,H,W]
            Ground-truth inverse depth map for the original image
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Match predicted scales for ground-truth
        gt_inv_depths = match_scales(gt_inv_depth,
                                     inv_depths,
                                     self.n,
                                     mode='nearest',
                                     align_corners=None)

        if self.mask_ego:
            device = gt_inv_depth.get_device()
            B = len(path_to_ego_mask)
            H_full = 800
            W_full = 1280
            ego_mask_tensor = torch.zeros(B, 1, H_full, W_full).to(device)
            for b in range(B):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
            ego_mask_tensors = []  # = torch.zeros(B, 1, 800, 1280)
            for i in range(self.n):
                B, C, H, W = inv_depths[i].shape
                if W < W_full:
                    ego_mask_tensors.append(
                        interpolate_image(ego_mask_tensor,
                                          shape=(B, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
                else:
                    ego_mask_tensors.append(ego_mask_tensor)

        # Calculate and store supervised loss
        if self.mask_ego:
            loss = self.calculate_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(gt_inv_depths, ego_mask_tensors)])
        else:
            loss = self.calculate_loss(inv_depths, gt_inv_depths)
        self.add_metric('supervised_loss', loss)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
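# calculate_loss is defined elsewhere in the repo; purely as an illustration of
# the masked comparison performed above, a masked L1 loss on inverse depths
# could look like the sketch below (an assumption, not necessarily the loss
# actually used by calculate_loss):
import torch


def masked_l1_inv_depth_loss(pred_inv_depths, gt_inv_depths, masks):
    """Hypothetical sketch: mean L1 over valid (mask > 0) pixels, averaged over scales."""
    losses = []
    for pred, gt, mask in zip(pred_inv_depths, gt_inv_depths, masks):
        valid = mask > 0
        if valid.any():
            losses.append((pred[valid] - gt[valid]).abs().mean())
    return torch.stack(losses).mean() if losses else torch.tensor(0.0)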
Example no. 5
def compute_depth_metrics(config, gt, pred, use_gt_scale=True):
    """
    Compute depth metrics from predicted and ground-truth depth maps

    Parameters
    ----------
    config : CfgNode
        Metrics parameters
    gt : torch.Tensor [B,1,H,W]
        Ground-truth depth map
    pred : torch.Tensor [B,1,H,W]
        Predicted depth map
    use_gt_scale : bool
        True if ground-truth median-scaling is to be used

    Returns
    -------
    metrics : torch.Tensor [7]
        Depth metrics (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3)
    """
    crop = config.crop == 'garg'

    # Initialize variables
    batch_size, _, gt_height, gt_width = gt.shape
    abs_diff = abs_rel = sq_rel = rmse = rmse_log = a1 = a2 = a3 = 0.0
    # Interpolate predicted depth to ground-truth resolution
    pred = interpolate_image(pred, gt.shape, mode='bilinear', align_corners=True)
    # If using crop
    if crop:
        crop_mask = torch.zeros(gt.shape[-2:]).byte().type_as(gt)
        y1, y2 = int(0.40810811 * gt_height), int(0.99189189 * gt_height)
        x1, x2 = int(0.03594771 * gt_width), int(0.96405229 * gt_width)
        crop_mask[y1:y2, x1:x2] = 1
    # For each depth map
    for pred_i, gt_i in zip(pred, gt):
        gt_i, pred_i = torch.squeeze(gt_i), torch.squeeze(pred_i)
        # Keep valid pixels (min/max depth and crop)
        valid = (gt_i > config.min_depth) & (gt_i < config.max_depth)
        valid = valid & crop_mask.bool() if crop else valid
        # Stop if there are no remaining valid pixels
        if valid.sum() == 0:
            continue
        # Keep only valid pixels
        gt_i, pred_i = gt_i[valid], pred_i[valid]
        # Ground-truth median scaling if needed
        if use_gt_scale:
            pred_i = pred_i * torch.median(gt_i) / torch.median(pred_i)
        # Clamp predicted depth values to min/max values
        pred_i = pred_i.clamp(config.min_depth, config.max_depth)

        # Calculate depth metrics

        thresh = torch.max((gt_i / pred_i), (pred_i / gt_i))
        a1 += (thresh < 1.25     ).float().mean()
        a2 += (thresh < 1.25 ** 2).float().mean()
        a3 += (thresh < 1.25 ** 3).float().mean()

        diff_i = gt_i - pred_i
        abs_diff += torch.mean(torch.abs(diff_i))
        abs_rel += torch.mean(torch.abs(diff_i) / gt_i)
        sq_rel += torch.mean(diff_i ** 2 / gt_i)
        rmse += torch.sqrt(torch.mean(diff_i ** 2))
        rmse_log += torch.sqrt(torch.mean((torch.log(gt_i) -
                                           torch.log(pred_i)) ** 2))
    # Return average values for each metric
    return torch.tensor([metric / batch_size for metric in
        [abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3]]).type_as(gt)
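# A tiny self-contained check of the metric formulas used above, on dummy 1-D
# depth values (no cropping or interpolation; numbers are arbitrary):
import torch

gt_i = torch.tensor([2.0, 4.0, 8.0])
pred_i = torch.tensor([2.2, 3.5, 9.0])
# median scaling, as done when use_gt_scale=True
pred_i = pred_i * torch.median(gt_i) / torch.median(pred_i)
thresh = torch.max(gt_i / pred_i, pred_i / gt_i)
a1 = (thresh < 1.25).float().mean()
abs_rel = torch.mean(torch.abs(gt_i - pred_i) / gt_i)
rmse = torch.sqrt(torch.mean((gt_i - pred_i) ** 2))
print(a1.item(), abs_rel.item(), rmse.item())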
    def forward(self,
                image,
                ref_images_temporal_context,
                ref_images_geometric_context,
                ref_images_geometric_context_temporal_context,
                inv_depths,
                poses_temporal_context,
                poses_geometric_context,
                poses_geometric_context_temporal_context,
                camera_type,
                intrinsics_poly_coeffs,
                intrinsics_principal_point,
                intrinsics_scale_factors,
                intrinsics_K,
                intrinsics_k,
                intrinsics_p,
                path_to_ego_mask,
                camera_type_geometric_context,
                intrinsics_poly_coeffs_geometric_context,
                intrinsics_principal_point_geometric_context,
                intrinsics_scale_factors_geometric_context,
                intrinsics_K_geometric_context,
                intrinsics_k_geometric_context,
                intrinsics_p_geometric_context,
                path_to_ego_mask_geometric_context,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        ref_images_temporal_context : list of torch.Tensor [B,3,H,W]
            Reference images from the temporal context
        ref_images_geometric_context : list of torch.Tensor [B,3,H,W]
            Reference images from the geometric (spatial) context
        ref_images_geometric_context_temporal_context : list of torch.Tensor [B,3,H,W]
            Reference images from the temporal context of each geometric context camera
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        poses_temporal_context : list of Pose
            Camera transformations between original and temporal context images
        poses_geometric_context : list
            Camera transformations between original and geometric context cameras
        poses_geometric_context_temporal_context : list
            Camera transformations between original and geometric-temporal context images
        camera_type : torch.Tensor
            Camera model type of the original camera
        intrinsics_poly_coeffs, intrinsics_principal_point, intrinsics_scale_factors, intrinsics_K, intrinsics_k, intrinsics_p
            Intrinsic parameters of the original camera
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        camera_type_geometric_context : torch.Tensor [B,n_geometric_context]
            Camera model type of each geometric context camera (2 = dummy camera)
        intrinsics_poly_coeffs_geometric_context, intrinsics_principal_point_geometric_context, intrinsics_scale_factors_geometric_context, intrinsics_K_geometric_context, intrinsics_k_geometric_context, intrinsics_p_geometric_context
            Intrinsic parameters of the geometric context cameras (one entry per camera)
        path_to_ego_mask_geometric_context : list
            Paths to the ego-mask .npy files of the geometric context cameras
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_temporal_context = len(ref_images_temporal_context)
        n_geometric_context = len(ref_images_geometric_context)
        assert (len(ref_images_geometric_context_temporal_context)
                == n_temporal_context * n_geometric_context)
        B = len(path_to_ego_mask)

        device = image.get_device()

        # getting ego masks for target and source cameras
        # fullsize mask
        H_full = 800
        W_full = 1280
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor_geometric_context = []
        for i_geometric_context in range(n_geometric_context):
            ref_ego_mask_tensor_geometric_context.append(
                torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            ego_mask_tensor[b, 0] = torch.from_numpy(
                np.load(path_to_ego_mask[b])).float()
            for i_geometric_context in range(n_geometric_context):
                if camera_type_geometric_context[b, i_geometric_context] != 2:
                    ref_ego_mask_tensor_geometric_context[i_geometric_context][b, 0] = \
                        torch.from_numpy(np.load(path_to_ego_mask_geometric_context[i_geometric_context][b])).float()

        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors_geometric_context = []
        for i_geometric_context in range(n_geometric_context):
            ref_ego_mask_tensors_geometric_context.append([])
        for i in range(self.n):
            _, _, H, W = images[i].shape
            if W < W_full:
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor,
                                      shape=(B, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
                for i_geometric_context in range(n_geometric_context):
                    ref_ego_mask_tensors_geometric_context[
                        i_geometric_context].append(
                            interpolate_image(
                                ref_ego_mask_tensor_geometric_context[
                                    i_geometric_context],
                                shape=(B, 1, H, W),
                                mode='nearest',
                                align_corners=None))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_geometric_context in range(n_geometric_context):
                    ref_ego_mask_tensors_geometric_context[
                        i_geometric_context].append(
                            ref_ego_mask_tensor_geometric_context[
                                i_geometric_context])

        # Dummy camera mask (B x n_geometric_context)
        Cmask = (camera_type_geometric_context == 2).detach()

        # temporal context
        for j, (ref_image, pose) in enumerate(
                zip(ref_images_temporal_context, poses_temporal_context)):
            # Calculate warped images
            ref_warped, ref_ego_mask_tensors_warped = \
                self.warp_ref_image(inv_depths,
                                    camera_type,
                                    intrinsics_poly_coeffs,
                                    intrinsics_principal_point,
                                    intrinsics_scale_factors,
                                    intrinsics_K,
                                    intrinsics_k,
                                    intrinsics_p,
                                    ref_image,
                                    pose,
                                    ego_mask_tensors,
                                    camera_type,
                                    intrinsics_poly_coeffs,
                                    intrinsics_principal_point,
                                    intrinsics_scale_factors,
                                    intrinsics_K,
                                    intrinsics_k,
                                    intrinsics_p)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors_warped[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ego_mask_tensors[i])

        # geometric context
        for j, (ref_image, pose) in enumerate(
                zip(ref_images_geometric_context, poses_geometric_context)):
            if Cmask[:, j].sum() < B:
                # Calculate warped images
                ref_warped, ref_ego_mask_tensors_warped = \
                    self.warp_ref_image(inv_depths,
                                        camera_type,
                                        intrinsics_poly_coeffs,
                                        intrinsics_principal_point,
                                        intrinsics_scale_factors,
                                        intrinsics_K,
                                        intrinsics_k,
                                        intrinsics_p,
                                        ref_image,
                                        Pose(pose),
                                        ref_ego_mask_tensors_geometric_context[j],
                                        camera_type_geometric_context[:, j],
                                        intrinsics_poly_coeffs_geometric_context[j],
                                        intrinsics_principal_point_geometric_context[j],
                                        intrinsics_scale_factors_geometric_context[j],
                                        intrinsics_K_geometric_context[j],
                                        intrinsics_k_geometric_context[j],
                                        intrinsics_p_geometric_context[j])
                # Calculate and store image loss
                photometric_loss = self.calc_photometric_loss(
                    ref_warped, images)
                for i in range(self.n):
                    photometric_loss[i][Cmask[:, j]] = 0.0
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])
                # If using automask
                if self.automask_loss:
                    # Calculate and store unwarped image loss
                    ref_images = match_scales(ref_image, inv_depths, self.n)
                    unwarped_image_loss = self.calc_photometric_loss(
                        ref_images, images)
                    for i in range(self.n):
                        unwarped_image_loss[i][Cmask[:, j]] = 0.0
                        photometric_losses[i].append(
                            unwarped_image_loss[i] * ego_mask_tensors[i] *
                            ref_ego_mask_tensors_geometric_context[j][i])

        # geometric-temporal context
        for j, (ref_image, pose) in enumerate(
                zip(ref_images_geometric_context_temporal_context,
                    poses_geometric_context_temporal_context)):
            j_geometric = j // n_temporal_context
            if Cmask[:, j_geometric].sum() < B:
                # Calculate warped images
                ref_warped, ref_ego_mask_tensors_warped = \
                    self.warp_ref_image(inv_depths,
                                        camera_type,
                                        intrinsics_poly_coeffs,
                                        intrinsics_principal_point,
                                        intrinsics_scale_factors,
                                        intrinsics_K,
                                        intrinsics_k,
                                        intrinsics_p,
                                        ref_image,
                                        Pose(pose.mat @ poses_geometric_context[j_geometric]), # WARNING: to be verified (change of reference frame!)
                                        ref_ego_mask_tensors_geometric_context[j_geometric],
                                        camera_type_geometric_context[:, j_geometric],
                                        intrinsics_poly_coeffs_geometric_context[j_geometric],
                                        intrinsics_principal_point_geometric_context[j_geometric],
                                        intrinsics_scale_factors_geometric_context[j_geometric],
                                        intrinsics_K_geometric_context[j_geometric],
                                        intrinsics_k_geometric_context[j_geometric],
                                        intrinsics_p_geometric_context[j_geometric])
                # Calculate and store image loss
                photometric_loss = self.calc_photometric_loss(
                    ref_warped, images)
                for i in range(self.n):
                    photometric_loss[i][Cmask[:, j_geometric]] = 0.0
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])
                # If using automask
                if self.automask_loss:
                    # Calculate and store unwarped image loss
                    ref_images = match_scales(ref_image, inv_depths, self.n)
                    unwarped_image_loss = self.calc_photometric_loss(
                        ref_images, images)
                    for i in range(self.n):
                        unwarped_image_loss[i][Cmask[:, j_geometric]] = 0.0
                        photometric_losses[i].append(
                            unwarped_image_loss[i] * ego_mask_tensors[i] *
                            ref_ego_mask_tensors_geometric_context[j_geometric]
                            [i])

        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
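# The Cmask logic above skips context cameras flagged as dummy
# (camera_type == 2): if every sample in the batch is dummy the warp is skipped
# entirely, otherwise the photometric loss of the dummy samples is zeroed.
# A minimal standalone sketch of that masking step (shapes and values are
# placeholders):
import torch

B, n_ctx, H, W = 4, 2, 8, 8
camera_type_geometric_context = torch.tensor([[0, 2], [0, 2], [0, 0], [2, 2]])
Cmask = (camera_type_geometric_context == 2)
j = 1
photometric_loss_i = torch.rand(B, 1, H, W)
if Cmask[:, j].sum() < B:                  # at least one valid camera in the batch
    photometric_loss_i[Cmask[:, j]] = 0.0  # zero the loss of dummy-camera samples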
front_img_torch = torch.transpose(
    torch.from_numpy(np.array(img_front)).float().unsqueeze(0).to(
        torch.device('cuda')), 0, 3).squeeze(3).unsqueeze(0)
front_next_img_torch = torch.transpose(
    torch.from_numpy(np.array(img_front_next)).float().unsqueeze(0).to(
        torch.device('cuda')), 0, 3).squeeze(3).unsqueeze(0)
left_img_torch = torch.transpose(
    torch.from_numpy(np.array(img_left)).float().unsqueeze(0).to(
        torch.device('cuda')), 0, 3).squeeze(3).unsqueeze(0)
right_img_torch = torch.transpose(
    torch.from_numpy(np.array(img_right)).float().unsqueeze(0).to(
        torch.device('cuda')), 0, 3).squeeze(3).unsqueeze(0)

new_shape = [160, 256]
scale_factor = 160 / float(800)

front_img_torch_small = interpolate_image(front_img_torch,
                                          new_shape,
                                          mode='bilinear',
                                          align_corners=True)
front_next_img_torch_small = interpolate_image(front_next_img_torch,
                                               new_shape,
                                               mode='bilinear',
                                               align_corners=True)
left_img_torch_small = interpolate_image(left_img_torch,
                                         new_shape,
                                         mode='bilinear',
                                         align_corners=True)
right_img_torch_small = interpolate_image(right_img_torch,
                                          new_shape,
                                          mode='bilinear',
                                          align_corners=True)

simulated_depth_small = interpolate_image(simulated_depth,
    def forward(self, image, context, inv_depths, ref_inv_depths,
                path_to_theta_lut,     path_to_ego_mask,     poly_coeffs,     principal_point,     scale_factors,
                ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors, # ALL LISTS !!!
                same_timestep_as_origin,
                pose_matrix_context,
                poses, return_logs=False, progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        ref_inv_depths : list
            Predicted inverse depth maps for the reference images
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Intrinsic calibration parameters and theta-LUT / ego-mask paths of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
            Intrinsic calibration parameters and theta-LUT / ego-mask paths of the reference cameras (one list entry per context image)
        same_timestep_as_origin : list
            Whether each context image was captured at the same timestep as the original
        pose_matrix_context : list
            Pose matrices between the original and context cameras
        poses : list of Pose
            Camera transformations between original and context images
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)


        n_context = len(context)
        device = image.get_device()
        B = len(path_to_ego_mask)
        H_full, W_full = np.load(path_to_ego_mask[0]).shape

        # getting ego masks for target and source cameras
        # fullsize mask
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            if self.mask_ego:
                ego_mask_tensor[b, 0] = torch.from_numpy(np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(np.load(ref_path_to_ego_mask[i_context][b])).float()
        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            B, C, H, W = images[i].shape
            if W < W_full:
                inv_scale_factor = int(W_full / W)
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor, shape=(B, 1, H, W), mode='nearest', align_corners=None)
                    #-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor)
                )
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        interpolate_image(ref_ego_mask_tensor[i_context],
                                          shape=(B, 1, H, W),
                                          mode='nearest',
                                          align_corners=None)
                        #-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
                    )
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(ref_ego_mask_tensor[i_context])
        for i_context in range(n_context):
            B, C, H, W = context[i_context].shape
            if W < W_full:
                inv_scale_factor = int(W_full / W)
                ref_ego_mask_tensor[i_context] = interpolate_image(ref_ego_mask_tensor[i_context],
                                                                   shape=(B, 1, H, W),
                                                                   mode='nearest',
                                                                   align_corners=None)
                #-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])

        B = len(path_to_ego_mask)

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            ref_warped, ref_ego_mask_tensors_warped, inv_depths_wrt_ref_cam, ref_inv_depths_warped \
                = self.warp_ref(inv_depths, ref_image, ref_ego_mask_tensor[j], ref_inv_depths[j],
                                path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors,
                                ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j],
                                ref_principal_point[j], ref_scale_factors[j],
                                same_timestep_as_origin[j], pose_matrix_context[j],
                                pose,
                                warp_ref_depth=self.use_ref_depth,
                                allow_context_rotation=self.allow_context_rotation)

            photometric_loss = self.calc_photometric_loss(ref_warped, images)

            if (self.mask_occlusion or self.mask_disocclusion) and (self.mask_spatial_context or self.mask_temporal_context):
                no_occlusion_masks    = [(inv_depths_wrt_ref_cam[i] <= self.mult_margin_occlusion * ref_inv_depths_warped[i])
                                         + (inv2depth(ref_inv_depths_warped[i]) <= self.add_margin_occlusion  + inv2depth(inv_depths_wrt_ref_cam[i]))
                                         for i in range(self.n)] # boolean OR
                no_disocclusion_masks = [(ref_inv_depths_warped[i] <= self.mult_margin_occlusion * inv_depths_wrt_ref_cam[i])
                                         + (inv2depth(inv_depths_wrt_ref_cam[i]) <= self.add_margin_occlusion  + inv2depth(ref_inv_depths_warped[i]))
                                         for i in range(self.n)] # boolean OR
                if self.mask_occlusion and self.mask_disocclusion:
                    valid_pixels_occ = [(no_occlusion_masks[i] * no_disocclusion_masks[i]).float() for i in range(self.n)]
                elif self.mask_occlusion:
                    valid_pixels_occ = [no_occlusion_masks[i].float() for i in range(self.n)]
                elif self.mask_disocclusion:
                    valid_pixels_occ = [no_disocclusion_masks[i].float() for i in range(self.n)]
                for b in range(B):
                    if ((same_timestep_as_origin[j][b] and not self.mask_spatial_context)
                            or (not same_timestep_as_origin[j][b] and not self.mask_temporal_context)):
                        for i in range(self.n):
                            valid_pixels_occ[i][b, :, :, :] = 1.0
            else:
                valid_pixels_occ = [torch.ones_like(inv_depths[i]) for i in range(self.n)]

            if self.depth_consistency_weight > 0.0:
                consistency_tensors_1 = [self.depth_consistency_weight * inv_depths_wrt_ref_cam[i] * torch.abs(inv2depth(inv_depths_wrt_ref_cam[i]) - inv2depth(ref_inv_depths_warped[i])) for i in range(self.n)]
                consistency_tensors_2 = [self.depth_consistency_weight * ref_inv_depths_warped[i]  * torch.abs(inv2depth(inv_depths_wrt_ref_cam[i]) - inv2depth(ref_inv_depths_warped[i])) for i in range(self.n)]
                consistency_tensors = [torch.cat([consistency_tensors_1[i], consistency_tensors_2[i]],1).min(1, True)[0] for i in range(self.n)]
            else:
                consistency_tensors = [torch.zeros_like(inv_depths[i]) for i in range(self.n)]

            for i in range(self.n):
                photometric_losses[i].append((photometric_loss[i] + consistency_tensors[i]) * ego_mask_tensors[i] * ref_ego_mask_tensors_warped[i] * valid_pixels_occ[i])

            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] * ego_mask_tensors[i] * ref_ego_mask_tensors[j][i])

        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss([a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                                              [a * b for a, b in zip(images,     ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
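# The depth-consistency term above penalizes disagreement between the depth
# predicted in the reference camera frame and the warped reference depth,
# weighting by inverse depth and keeping the per-pixel minimum of the two
# weightings. A standalone sketch for a single scale (inv2depth is assumed to
# be a clamped reciprocal; the repo's version may differ):
import torch


def inv2depth_sketch(inv_depth, eps=1e-6):
    # hypothetical stand-in for the repo's inv2depth
    return 1.0 / inv_depth.clamp(min=eps)


def depth_consistency_sketch(inv_depth_wrt_ref, ref_inv_depth_warped, weight):
    diff = torch.abs(inv2depth_sketch(inv_depth_wrt_ref) -
                     inv2depth_sketch(ref_inv_depth_warped))
    c1 = weight * inv_depth_wrt_ref * diff
    c2 = weight * ref_inv_depth_warped * diff
    return torch.cat([c1, c2], 1).min(1, True)[0]   # [B,1,H,W]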
    def forward(self,
                image,
                context,
                inv_depths,
                K,
                k,
                p,
                path_to_ego_mask,
                ref_K,
                ref_k,
                ref_p,
                path_to_ego_mask_context,
                poses,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        k : torch.Tensor [B,3]
            Original camera radial distortion coefficients (k1, k2, k3)
        p : torch.Tensor [B,2]
            Original camera tangential distortion coefficients (p1, p2)
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        ref_k : torch.Tensor [B,3]
            Reference camera radial distortion coefficients
        ref_p : torch.Tensor [B,2]
            Reference camera tangential distortion coefficients
        path_to_ego_mask_context : list
            Paths to the ego-mask .npy files of the context cameras
        poses : list of Pose
            Camera transformations between original and context images
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_context = len(context)

        if self.mask_ego:
            device = image.get_device()
            B = len(path_to_ego_mask)
            H_full, W_full = np.load(path_to_ego_mask[0]).shape

            ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
            ref_ego_mask_tensor = []
            for i_context in range(n_context):
                ref_ego_mask_tensor.append(
                    torch.ones(B, 1, H_full, W_full).to(device))

            for b in range(B):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(
                            path_to_ego_mask_context[i_context][b])).float()

            # resized masks
            ego_mask_tensors = []
            ref_ego_mask_tensors = []
            for i_context in range(n_context):
                ref_ego_mask_tensors.append([])
            for i in range(self.n):
                Btmp, C, H, W = images[i].shape
                if W < W_full:
                    # inv_scale_factor = int(W_full / W)
                    # print(W_full / W)
                    # ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                    ego_mask_tensors.append(
                        interpolate_image(ego_mask_tensor,
                                          shape=(Btmp, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
                    for i_context in range(n_context):
                        # ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                        ref_ego_mask_tensors[i_context].append(
                            interpolate_image(ref_ego_mask_tensor[i_context],
                                              shape=(Btmp, 1, H, W),
                                              mode='nearest',
                                              align_corners=None))
                else:
                    ego_mask_tensors.append(ego_mask_tensor)
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(
                            ref_ego_mask_tensor[i_context])

            for i_context in range(n_context):
                _, C, H, W = context[i_context].shape
                if W < W_full:
                    inv_scale_factor = int(W_full / W)
                    # ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
                    ref_ego_mask_tensor[i_context] = interpolate_image(
                        ref_ego_mask_tensor[i_context],
                        shape=(Btmp, 1, H, W),
                        mode='nearest',
                        align_corners=None)

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            if self.mask_ego:
                ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image_tensor(
                    inv_depths, ref_image, K, k, p, ref_K, ref_k, ref_p,
                    ref_ego_mask_tensor[j], pose)
                # if torch.isnan(ref_warped[0]).sum() > 0:
                #     print('ref_warped')
                #     print(ref_warped[0])
                #     print(torch.isnan(ref_warped[0]).sum())
                #     print(path_to_ego_mask)
                #     print(torch.isnan(ref_warped[0]).sum(dim=0))
                #     B, _, H, W = ref_image.shape
                #     device = ref_image.get_device()
                #     # Generate cameras for all scales
                #     cams, ref_cams = [], []
                #     for i in range(self.n):
                #         _, _, DH, DW = inv_depths[i].shape
                #         scale_factor = DW / float(W)
                #         cams.append(CameraDistorted(K=K.float(), k1=k[:, 0], k2=k[:, 1], k3=k[:, 2], p1=p[:, 0],
                #                                     p2=p[:, 1]).scaled(scale_factor).to(device))
                #         ref_cams.append(CameraDistorted(K=ref_K.float(), k1=ref_k[:, 0], k2=ref_k[:, 1], k3=ref_k[:, 2],
                #                                         p1=ref_p[:, 0], p2=ref_p[:, 1], Tcw=pose).scaled(
                #             scale_factor).to(device))
                #     # View synthesis
                #     depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
                #     ref_images = match_scales(ref_image, inv_depths, self.n)
                #     ref_warped = [view_synthesis(
                #         ref_images[i], depths[i], ref_cams[i], cams[i],
                #         padding_mode=self.padding_mode) for i in range(self.n)]
                #
                #     for i in range(self.n):
                #         world_points = cams[i].reconstruct(depths[i], frame='w')
                #         ref_coords = ref_cams[i].project(world_points, frame='w')
                #         print(i)
                #         print('world_points')
                #         print('min')
                #         print(torch.min(world_points[:, 0, :, :]))
                #         print(torch.min(world_points[:, 1, :, :]))
                #         print(torch.min(world_points[:, 2, :, :]))
                #         print('max')
                #         print(torch.max(world_points[:, 0, :, :]))
                #         print(torch.max(world_points[:, 1, :, :]))
                #         print(torch.max(world_points[:, 2, :, :]))
                #         print('ref_coords')
                #         print('min')
                #         print(torch.min(ref_coords[:, :, :, 0]))
                #         print(torch.min(ref_coords[:, :, :, 1]))
                #         print('max')
                #         print(torch.max(ref_coords[:, :, :, 0]))
                #         print(torch.max(ref_coords[:, :, :, 1]))

            else:
                ref_warped = self.warp_ref_image(inv_depths, ref_image, K, k,
                                                 p, ref_K, ref_k, ref_p, pose)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            if self.mask_ego:
                for i in range(self.n):
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])
            else:
                for i in range(self.n):
                    photometric_losses[i].append(photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
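# calc_photometric_loss is defined elsewhere in these loss classes. In
# self-supervised depth estimation it is conventionally a weighted SSIM + L1
# per-pixel term; the sketch below shows that convention (an assumption about,
# not a copy of, the repo's implementation):
import torch
import torch.nn.functional as F


def ssim_sketch(x, y, C1=0.01 ** 2, C2=0.03 ** 2):
    """Simplified SSIM with 3x3 average pooling."""
    mu_x = F.avg_pool2d(x, 3, 1, 1)
    mu_y = F.avg_pool2d(y, 3, 1, 1)
    sigma_x = F.avg_pool2d(x ** 2, 3, 1, 1) - mu_x ** 2
    sigma_y = F.avg_pool2d(y ** 2, 3, 1, 1) - mu_y ** 2
    sigma_xy = F.avg_pool2d(x * y, 3, 1, 1) - mu_x * mu_y
    ssim_n = (2 * mu_x * mu_y + C1) * (2 * sigma_xy + C2)
    ssim_d = (mu_x ** 2 + mu_y ** 2 + C1) * (sigma_x + sigma_y + C2)
    return torch.clamp((1 - ssim_n / ssim_d) / 2, 0, 1)


def photometric_loss_sketch(warped, target, alpha=0.85):
    """Per-pixel photometric loss: alpha * SSIM + (1 - alpha) * L1, shape [B,1,H,W]."""
    l1 = torch.abs(warped - target).mean(1, True)
    ssim = ssim_sketch(warped, target).mean(1, True)
    return alpha * ssim + (1 - alpha) * l1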