    def forward(self, image, context, inv_depths,
                K, ref_K, poses, return_logs=False, progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)
        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            ref_warped = self.warp_ref_image(inv_depths, ref_image, K, ref_K, pose)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i])
        # Calculate reduced photometric loss
        loss = self.reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
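
The reduction step used above (reduce_photometric_loss) is not shown in this snippet. Below is a minimal sketch of a Monodepth2-style per-pixel minimum reduction, assuming each entry of photometric_losses[i] is a [B,1,H,W] tensor; the actual method of this class may differ.

import torch

def reduce_photometric_loss_min_sketch(photometric_losses):
    # photometric_losses: one list per scale, each holding [B,1,H,W] loss maps
    # (warped and, with automasking, unwarped terms).
    scale_losses = []
    for losses in photometric_losses:
        stacked = torch.cat(losses, dim=1)       # [B,N,H,W]
        min_loss, _ = torch.min(stacked, dim=1)  # per-pixel minimum over terms
        scale_losses.append(min_loss.mean())
    # Average the per-scale means
    return sum(scale_losses) / len(scale_losses)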
    def warp_ref_image(self, inv_depths, ref_image, path_to_theta_lut,
                       poly_coeffs, principal_point, scale_factor_y,
                       ref_path_to_theta_lut, ref_poly_coeffs,
                       ref_principal_point, ref_scale_factor_y, pose):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        path_to_theta_lut, poly_coeffs, principal_point, scale_factor_y
            Fisheye (Woodscape) calibration of the original camera
        ref_path_to_theta_lut, ref_poly_coeffs, ref_principal_point, ref_scale_factor_y
            Fisheye (Woodscape) calibration of the reference camera
        pose : Pose
            Original -> Reference camera transformation

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at each scale (reconstructing the original one)
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(
                CameraFisheyeWoodscape(
                    path_to_theta_lut=path_to_theta_lut,
                    poly_coeffs=poly_coeffs.float(),
                    principal_point=principal_point.float(),
                    scale_factor_y=scale_factor_y.float()).scaled(
                        scale_factor).to(device))
            ref_cams.append(
                CameraFisheyeWoodscape(
                    path_to_theta_lut=ref_path_to_theta_lut,
                    poly_coeffs=ref_poly_coeffs.float(),
                    principal_point=ref_principal_point.float(),
                    scale_factor_y=ref_scale_factor_y.float(),
                    Tcw=pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [
            view_synthesis(ref_images[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode)
            for i in range(self.n)
        ]
        # Return warped reference image
        return ref_warped
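
view_synthesis is called above but not defined in this snippet. The sketch below shows the standard pipeline (lift target pixels to 3D with the predicted depth, project them into the reference camera, sample with grid_sample); the cam.reconstruct / ref_cam.project names and signatures are assumptions in the style of packnet-sfm camera classes.

import torch.nn.functional as F

def view_synthesis_sketch(ref_image, depth, ref_cam, cam, padding_mode='zeros'):
    # Assumed API: cam.reconstruct returns [B,3,H,W] world points from depth,
    # ref_cam.project returns a normalized [B,H,W,2] sampling grid in [-1, 1].
    world_points = cam.reconstruct(depth, frame='w')
    ref_coords = ref_cam.project(world_points, frame='w')
    # Sample the reference image at the projected coordinates
    return F.grid_sample(ref_image, ref_coords, mode='bilinear',
                         padding_mode=padding_mode, align_corners=True)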
Example #3
    def warp_ref_image_spatial(self, inv_depths, ref_image, K, ref_K,
                               extrinsics_1, extrinsics_2):
        """
        Warps a reference image to produce a reconstruction of the original one (spatial-wise).

        Parameters
        ----------
        inv_depths : list of torch.Tensor [6,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [6,3,H,W]
            Reference RGB image
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        extrinsics_1 : torch.Tensor [B,4,4]
            Target image extrinsics
        extrinsics_2 : torch.Tensor [B,4,4]
            Context image extrinsics

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at each scale (reconstructing the original one)
        valid_points_masks : list of torch.Tensor
            Masks of reprojected points that fall inside the reference image, one per scale
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(
                Camera(K=K.float(),
                       Tcw=extrinsics_1).scaled(scale_factor).to(device))
            ref_cams.append(
                Camera(K=ref_K.float(),
                       Tcw=extrinsics_2).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = []
        ref_coords = []
        for i in range(self.n):
            w, c = view_synthesis(ref_images[i],
                                  depths[i],
                                  ref_cams[i],
                                  cams[i],
                                  padding_mode=self.padding_mode)
            ref_warped.append(w)
            ref_coords.append(c)
        # calculate valid_points_mask
        valid_points_masks = [
            ref_coords[i].abs().max(dim=-1)[0] <= 1 for i in range(self.n)
        ]
        return ref_warped, valid_points_masks
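
The validity masks above rely on the sampling grid being expressed in normalized coordinates, so any point with |x| > 1 or |y| > 1 lands outside the reference image. A tiny standalone check:

import torch

# ref_coords: [B,H,W,2] normalized sampling grid; out-of-range points are invalid.
ref_coords = torch.tensor([[[[0.5, -0.2], [1.3, 0.0]]]])  # [1,1,2,2]
valid = ref_coords.abs().max(dim=-1)[0] <= 1              # [1,1,2]
print(valid)  # tensor([[[ True, False]]])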
    def warp_ref_image(self, inv_depths, ref_image, K, k, p, ref_K, ref_k,
                       ref_p, pose):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        k : torch.Tensor [B,3]
            Radial distortion coefficients (k1, k2, k3) of the original camera
        p : torch.Tensor [B,2]
            Tangential distortion coefficients (p1, p2) of the original camera
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        ref_k : torch.Tensor [B,3]
            Radial distortion coefficients of the reference camera
        ref_p : torch.Tensor [B,2]
            Tangential distortion coefficients of the reference camera
        pose : Pose
            Original -> Reference camera transformation

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at each scale (reconstructing the original one)
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(
                CameraDistorted(K=K.float(),
                                k1=k[:, 0],
                                k2=k[:, 1],
                                k3=k[:, 2],
                                p1=p[:, 0],
                                p2=p[:, 1]).scaled(scale_factor).to(device))
            ref_cams.append(
                CameraDistorted(K=ref_K.float(),
                                k1=ref_k[:, 0],
                                k2=ref_k[:, 1],
                                k3=ref_k[:, 2],
                                p1=ref_p[:, 0],
                                p2=ref_p[:, 1],
                                Tcw=pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [
            view_synthesis(ref_images[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode)
            for i in range(self.n)
        ]
        # Return warped reference image
        return ref_warped
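
CameraDistorted extends the pinhole model with radial (k1, k2, k3) and tangential (p1, p2) terms. A minimal sketch of the Brown-Conrady distortion applied to normalized camera coordinates, not the repository's implementation:

import torch

def distort_normalized(x, y, k1, k2, k3, p1, p2):
    # x, y: normalized image coordinates (before multiplying by K).
    r2 = x * x + y * y
    radial = 1 + k1 * r2 + k2 * r2 ** 2 + k3 * r2 ** 3
    x_d = x * radial + 2 * p1 * x * y + p2 * (r2 + 2 * x * x)
    y_d = y * radial + p1 * (r2 + 2 * y * y) + 2 * p2 * x * y
    return x_d, y_d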
Example #5
    def warp_ref_image(self, inv_depths, ref_image, raysurf_residual, pose,
                       progress):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        raysurf_residual : torch.Tensor
            Predicted residual added to the canonical ray surface
        pose : Pose
            Original -> Reference camera transformation
        progress : float
            Training percentage, used to progressively blend in the residual

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at each scale (reconstructing the original one)
        """
        B, _, H, W = ref_image.shape
        device = torch.device('cpu')  # originally: ref_image.get_device()
        # Generate cameras for all scales

        coeff = np.min([((100.0 * progress)**(4 / 3.) / 100.), 1.])
        Rmat = self.canonical_ray_surface.to(device) + coeff * raysurf_residual
        Rmat = Rmat / torch.norm(Rmat, dim=1, keepdim=True)

        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(GenericCamera_cpu(R=Rmat).to(device))
            ref_cams.append(GenericCamera_cpu(R=Rmat, Tcw=pose).to(device))

        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]

        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [
            view_synthesis_generic(ref_images[i],
                                   depths[i],
                                   ref_cams[i],
                                   cams[i],
                                   padding_mode=self.padding_mode,
                                   progress=progress) for i in range(self.n)
        ]
        # Return warped reference image
        return ref_warped
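
The blending coefficient above follows a saturating schedule: the learned ray-surface residual is introduced gradually and is fully active once (100 * progress) ** (4/3) reaches 100, i.e. at roughly 32% of training. A quick check of its values:

import numpy as np

for progress in [0.0, 0.1, 0.3, 0.5, 1.0]:
    coeff = np.min([((100.0 * progress) ** (4 / 3.) / 100.), 1.])
    print(progress, round(float(coeff), 3))
# 0.0 -> 0.0, 0.1 -> 0.215, 0.3 -> 0.932, 0.5 -> 1.0, 1.0 -> 1.0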
    def warp_ref_image(self, inv_depths, ref_image, K, ref_K, pose):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        pose : Pose
            Original -> Reference camera transformation

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at each scale (reconstructing the original one)
        """
        B, _, H, W = ref_image.shape
        # print('Reference image size')
        # print(ref_image.shape)
        # print('Warping the images, sizes:')
        device = torch.device('cpu')  # originally: ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            #print(inv_depths[i].shape)
            scale_factor = DW / float(W)
            #print(scale_factor)
            cams.append(Camera(K=K.float()).scaled(scale_factor).to(device))
            ref_cams.append(Camera(K=ref_K.float(), Tcw=pose).scaled(scale_factor).to(device))
            #print(Camera(K=K.float()).scaled(scale_factor).K)
            #print(Camera(K=ref_K.float()).scaled(scale_factor).K)
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [view_synthesis(
            ref_images[i], depths[i], ref_cams[i], cams[i],
            padding_mode=self.padding_mode) for i in range(self.n)]
        # Return warped reference image
        return ref_warped
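
Camera(...).scaled(scale_factor) above rescales the intrinsics so that each camera matches the resolution of the corresponding depth map. Below is a minimal sketch of that rescaling for a pinhole K; the repository's Camera.scaled may handle the principal point's half-pixel offset differently.

import torch

def scale_intrinsics_sketch(K, x_scale, y_scale=None):
    # K: [B,3,3] pinhole intrinsics; fx/cx scale with width, fy/cy with height.
    y_scale = x_scale if y_scale is None else y_scale
    K = K.clone()
    K[:, 0, 0] *= x_scale  # fx
    K[:, 1, 1] *= y_scale  # fy
    K[:, 0, 2] *= x_scale  # cx
    K[:, 1, 2] *= y_scale  # cy
    return K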
Example #7
    def forward(self,
                inv_depths,
                gt_inv_depth,
                return_logs=False,
                progress=0.0):
        """
        Calculates training supervised loss.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        gt_inv_depth : torch.Tensor [B,1,H,W]
            Ground-truth inverse depth map for the original image
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Match predicted scales for ground-truth
        gt_inv_depths = match_scales(gt_inv_depth,
                                     inv_depths,
                                     self.n,
                                     mode='nearest',
                                     align_corners=None)
        # Calculate and store supervised loss
        loss = self.calculate_loss(inv_depths, gt_inv_depths)
        self.add_metric('supervised_loss', loss)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
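
progressive_scaling(progress) decides how many scales are active at a given point of training; the helper itself is not shown in this snippet. Below is a hedged sketch of one plausible schedule that starts with all scales and ramps down to a single scale; the original ProgressiveScaling class may use a different rule.

class ProgressiveScalingSketch:
    def __init__(self, ramp=0.5, num_scales=4):
        # Use all `num_scales` at the start and only one scale once training
        # progress reaches `ramp` (a fraction in [0, 1]).
        self.ramp = ramp
        self.num_scales = num_scales

    def __call__(self, progress):
        if self.ramp <= 0.0 or progress >= self.ramp:
            return 1
        frac = 1.0 - progress / self.ramp
        return max(1, int(round(frac * self.num_scales)))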
    def forward(
            self,
            image,
            context,
            inv_depths,
            path_to_theta_lut,
            path_to_ego_mask,
            poly_coeffs,
            principal_point,
            scale_factors,
            ref_path_to_theta_lut,
            ref_path_to_ego_mask,
            ref_poly_coeffs,
            ref_principal_point,
            ref_scale_factors,  # the ref_* arguments are lists (one entry per context image)
            same_timestep_as_origin,
            pose_matrix_context,
            poses,
            return_logs=False,
            progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye (Woodscape) calibration and ego-mask paths of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
            Fisheye calibration and ego-mask paths of the reference cameras (lists, one entry per context image)
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_context = len(context)

        if self.mask_ego:
            device = image.get_device()

            B = len(path_to_ego_mask)

            ego_mask_tensor = torch.zeros(B, 1, 800, 1280).to(device)
            ref_ego_mask_tensor = []  # one [B,1,800,1280] mask per context image
            for i_context in range(n_context):
                ref_ego_mask_tensor.append(
                    torch.zeros(B, 1, 800, 1280).to(device))
            for b in range(B):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(ref_path_to_ego_mask[i_context][b])).float()

            ego_mask_tensors = []      # per-scale masks for the target camera
            ref_ego_mask_tensors = []  # per-context lists of per-scale masks
            for i_context in range(n_context):
                ref_ego_mask_tensors.append([])
            for i in range(self.n):
                B, C, H, W = images[i].shape
                if W < 1280:
                    inv_scale_factor = int(1280 / W)
                    ego_mask_tensors.append(-nn.MaxPool2d(
                        inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(
                            -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)
                            (-ref_ego_mask_tensor[i_context]))
                else:
                    ego_mask_tensors.append(ego_mask_tensor)
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(
                            ref_ego_mask_tensor[i_context])
                # ego_mask_tensor = ego_mask_tensor.to(device)

            for i_context in range(n_context):
                B, C, H, W = context[i_context].shape
                if W < 1280:
                    inv_scale_factor = int(1280 / W)
                    ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(
                        inv_scale_factor,
                        inv_scale_factor)(-ref_ego_mask_tensor[i_context])

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            if self.mask_ego:
                if self.warp_ego_tensor:
                    ref_warped, ref_ego_mask_tensors_warped \
                        = self.warp_ref_image_tensor(inv_depths, ref_image, ref_ego_mask_tensor[j],
                                                     path_to_theta_lut,        path_to_ego_mask,        poly_coeffs,        principal_point,        scale_factors,
                                                     ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j], ref_principal_point[j], ref_scale_factors[j],
                                                     same_timestep_as_origin[j],
                                                     pose_matrix_context[j],
                                                     pose)
                    photometric_loss = self.calc_photometric_loss([
                        a * b * c
                        for a, b, c in zip(ref_warped, ego_mask_tensors,
                                           ref_ego_mask_tensors_warped)
                    ], [a * b for a, b in zip(images, ego_mask_tensors)])

                else:
                    ref_warped = self.warp_ref_image(
                        inv_depths, ref_image * ref_ego_mask_tensor[j],
                        path_to_theta_lut, path_to_ego_mask, poly_coeffs,
                        principal_point, scale_factors,
                        ref_path_to_theta_lut[j], ref_path_to_ego_mask[j],
                        ref_poly_coeffs[j], ref_principal_point[j],
                        ref_scale_factors[j], same_timestep_as_origin[j],
                        pose_matrix_context[j], pose)
                    photometric_loss = self.calc_photometric_loss(
                        [a * b for a, b in zip(ref_warped, ego_mask_tensors)],
                        [a * b for a, b in zip(images, ego_mask_tensors)])
            else:
                ref_warped = self.warp_ref_image(
                    inv_depths, ref_image, path_to_theta_lut, path_to_ego_mask,
                    poly_coeffs, principal_point, scale_factors,
                    ref_path_to_theta_lut[j], ref_path_to_ego_mask[j],
                    ref_poly_coeffs[j], ref_principal_point[j],
                    ref_scale_factors[j], same_timestep_as_origin[j],
                    pose_matrix_context[j], pose)
                photometric_loss = self.calc_photometric_loss(
                    ref_warped, images)
            # tt = str(int(time.time() % 10000))
            # for i in range(self.n):
            #     B, C, H, W = images[i].shape
            #     print(photometric_loss[i][:,:,::20,::20])
            #     for b in range(B):
            #         orig_PIL_0   = torch.transpose((ref_image[b,:,:,:]).unsqueeze(0).unsqueeze(4), 1, 4).squeeze().detach().cpu().numpy()
            #         orig_PIL   = torch.transpose((ref_image[b,:,:,:]* ref_ego_mask_tensor[j][b,:,:,:]).unsqueeze(0).unsqueeze(4), 1, 4).squeeze().detach().cpu().numpy()
            #         warped_PIL_0 = torch.transpose((ref_warped[i][b,:,:,:]).unsqueeze(0).unsqueeze(4), 1, 4).squeeze().detach().cpu().numpy()
            #         warped_PIL = torch.transpose((ref_warped[i][b,:,:,:] * ego_mask_tensors[i][b,:,:,:]).unsqueeze(0).unsqueeze(4), 1, 4).squeeze().detach().cpu().numpy()
            #         target_PIL_0 = torch.transpose((images[i][b, :, :, :]).unsqueeze(0).unsqueeze(4), 1, 4).squeeze().detach().cpu().numpy()
            #         target_PIL = torch.transpose((images[i][b, :, :, :] * ego_mask_tensors[i][b,:,:,:]).unsqueeze(0).unsqueeze(4), 1, 4).squeeze().detach().cpu().numpy()
            #
            #         cv2.imwrite('/home/users/vbelissen/test' + '_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i) + '_orig_PIL_0.png',   orig_PIL_0*255)
            #         cv2.imwrite('/home/users/vbelissen/test' + '_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i) + '_orig_PIL.png',   orig_PIL*255)
            #         cv2.imwrite('/home/users/vbelissen/test' + '_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i) + '_warped_PIL_0.png', warped_PIL_0 * 255)
            #         cv2.imwrite('/home/users/vbelissen/test' + '_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i) + '_warped_PIL.png', warped_PIL * 255)
            #         cv2.imwrite('/home/users/vbelissen/test' + '_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i) + '_target_PIL_0.png', target_PIL_0 * 255)
            #         cv2.imwrite('/home/users/vbelissen/test' + '_' + str(j) + '_' + tt + '_' + str(b) + '_' + str(i) + '_target_PIL.png', target_PIL * 255)
            #
            # print(photometric_loss[0].shape)

            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                if self.mask_ego:
                    unwarped_image_loss = self.calc_photometric_loss([
                        a * b
                        for a, b in zip(ref_images, ref_ego_mask_tensors[j])
                    ], [a * b for a, b in zip(images, ego_mask_tensors)])
                else:
                    unwarped_image_loss = self.calc_photometric_loss(
                        ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i])
        # Calculate reduced photometric loss
        loss = self.reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            if self.mask_ego:
                loss += self.calc_smoothness_loss(
                    [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                    [a * b for a, b in zip(images, ego_mask_tensors)])
            else:
                loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
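
The -nn.MaxPool2d(s, s)(-mask) pattern above is a min-pool: a downsampled mask pixel stays valid only if every pixel in its window was valid, which keeps the ego mask conservative at lower scales. A tiny demonstration:

import torch
import torch.nn as nn

mask = torch.tensor([[[[1., 1., 0., 0.],
                       [1., 1., 0., 0.],
                       [1., 1., 1., 1.],
                       [1., 1., 1., 1.]]]])  # [1,1,4,4] binary ego mask
# Negating before and after MaxPool2d turns it into a min-pool, so a 2x2
# block survives only if all of its pixels are valid.
min_pooled = -nn.MaxPool2d(2, 2)(-mask)
print(min_pooled)
# tensor([[[[1., 0.],
#           [1., 1.]]]])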
Example #9
    def forward(self,
                image,
                context,
                inv_depths,
                K,
                ref_K,
                extrinsics,
                poses,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.
        (Temporal, spatial and temporal-spatial losses are all considered here.)

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        extrinsics : torch.Tensor [B,4,4]
            Extrinsics of the six cameras in the rig
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """

        # Step 0: prepare for loss calculation
        # Reorganize the camera order (see the dataset API documentation for details)
        image = self.sort_cameras_tensor(image)
        context = self.sort_cameras_tensor(context)
        inv_depths = self.sort_cameras_tensor(inv_depths)
        K = self.sort_cameras_tensor(K)
        ref_K = K
        poses = [
            Pose(self.sort_cameras_tensor(poses[0].item())),
            Pose(self.sort_cameras_tensor(poses[1].item()))
        ]

        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)
        extrinsics = torch.as_tensor(extrinsics,
                                     dtype=torch.float32,
                                     device=image.device)
        # Step 1: Calculate the losses temporal-wise
        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            ref_warped = self.warp_ref_image_temporal(inv_depths, ref_image, K,
                                                      ref_K, pose)

            # Calculate and store image loss

            # print('### poses shape', len(poses))
            # print('poses[0].shape:', poses[0].shape)
            # print('###multiview_photometric_loss printing ref_warped')
            # print('len of images: ',len(images))
            # print('shape of images[0]: ', images[0].shape)
            # print('len of context: ',len(context))
            # print('shape of context[0]:', context[0].shape)
            # print('len of ref_warped: ',len(ref_warped))
            # print('shape of ref_warped[0]', ref_warped[0].shape)

            # pic_orig = images[0][0].cpu().clone()
            # pic_orig = (pic_orig.squeeze(0).permute(1,2,0).detach().numpy()*255).astype(np.uint8)
            # pic_ref = context[0][0].cpu().clone()
            # pic_ref = (pic_ref.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
            # pic_warped = ref_warped[0][0].cpu().clone()
            # pic_warped = (pic_warped.squeeze(0).permute(1,2,0).detach().numpy()*255).astype(np.uint8)
            # final_frame = cv2.hconcat((pic_orig, pic_ref, pic_warped))
            # cv2.imshow('temporal warping', final_frame)
            # cv2.waitKey()

            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            for i in range(self.n):
                photometric_losses[i].append(self.temporal_loss_weight *
                                             photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(self.temporal_loss_weight *
                                                 unwarped_image_loss[i])

        # Step 2: Calculate the losses spatial-wise
        # reconstruct context images
        num_cameras, C, H, W = image.shape  # should be (6, 3, H, W)
        left_swap = [i for i in range(-1, num_cameras - 1)]
        right_swap = [i % 6 for i in range(1, num_cameras + 1)]

        # for i in range(num_cameras):
        #     pic = image[i].cpu().clone()
        #     pic = (pic.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
        #     if i == 0:
        #         final_file = pic
        #     else:
        #         final_file = cv2.hconcat((final_file, pic))
        # cv2.imshow('6 images', final_file)
        # cv2.waitKey()

        context_spatial = [[], []]  # [[B,3,H,W],[B,3,H,W]]
        context_spatial[0] = image[left_swap, ...]
        context_spatial[1] = image[right_swap, ...]
        K_spatial = K  # tensor [B,3,3]
        ref_K_spatial = [[], []]  # [[B,3,3],[B,3,3]]
        ref_K_spatial[0] = K[left_swap, ...]
        ref_K_spatial[1] = K[right_swap, ...]
        # reconstruct extrinsics
        extrinsics_1 = extrinsics  # [B,4,4]
        extrinsics_2 = [[], []]  # [[B,4,4],[B,4,4]]
        extrinsics_2[0] = extrinsics_1[left_swap, ...]
        extrinsics_2[1] = extrinsics_1[right_swap, ...]
        # calculate spatial-wise photometric loss
        for j, ref_image in enumerate(context_spatial):
            # Calculate warped images
            ref_warped, valid_points_masks = self.warp_ref_image_spatial(
                inv_depths, ref_image, K_spatial, ref_K_spatial[j],
                Pose(extrinsics_1), Pose(extrinsics_2[j]))

            # pic_orig = images[0][1].cpu().clone()
            # pic_orig = (pic_orig.squeeze(0).permute(1,2,0).detach().numpy()*255).astype(np.uint8)
            # pic_ref = context_spatial[0][1].cpu().clone()
            # pic_ref = (pic_ref.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
            # pic_warped = ref_warped[0][1].cpu().clone()
            # pic_warped = (pic_warped.squeeze(0).permute(1,2,0).detach().numpy()*255).astype(np.uint8)
            # pic_valid = valid_points_masks[0][1].cpu().clone()
            # pic_valid = (pic_valid.permute(1,2,0).detach().numpy()*255).astype(np.uint8)
            # final_frame = cv2.hconcat((pic_orig, pic_ref, pic_warped))
            # cv2.imshow('spatial warping', final_frame)
            # cv2.waitKey()
            # # cv2.imshow('pic_valid', pic_valid)
            # # cv2.waitKey()

            # Calculate and store image loss
            photometric_loss = [
                self.calc_photometric_loss(ref_warped, images)[i] *
                valid_points_masks[i] for i in range(len(valid_points_masks))
            ]
            for i in range(self.n):
                photometric_losses[i].append(self.spatial_loss_weight *
                                             photometric_loss[i])

        # Step 3: Calculate the loss temporal-spatial wise
        # reconstruct context images
        context_temporal_spatial = []
        # [context_temporal_spatial_backward, context_temporal_spatial_forward]
        # [[left t-1, right t-1], [left t+1, right t+1]]
        # [[[B,H,W],[B,H,W]],[[B,H,W],[B,H,W]]]
        # reconstruct intrinsics
        K_temporal_spatial = K
        ref_K_temporal_spatial = []
        # reconstruct extrinsics
        extrinsics_1_ts = extrinsics
        extrinsics_2_ts = []
        # reconstruct pose
        poses_ts = []
        for l in range(len(context)):
            context_temporal_spatial.append(
                [context[l][left_swap, ...], context[l][right_swap, ...]])
            ref_K_temporal_spatial.append([
                K_temporal_spatial[left_swap, ...],
                K_temporal_spatial[right_swap, ...]
            ])
            extrinsics_2_ts.append(
                [extrinsics[left_swap, ...], extrinsics[right_swap, ...]])
            poses_ts.append([
                Pose(poses[l].item()[left_swap, ...]),
                Pose(poses[l].item()[right_swap, ...])
            ])
        # calculate spatial-wise photometric loss
        for j, (ref_image,
                pose) in enumerate(zip(context_temporal_spatial, poses_ts)):
            # Calculate warped images
            for k in range(len(ref_image)):
                ref_warped, valid_points_masks = self.warp_ref_image_spatial(
                    inv_depths, ref_image[k], K_temporal_spatial,
                    ref_K_temporal_spatial[j][k], Pose(extrinsics_1_ts),
                    Pose(extrinsics_2_ts[j][k]) @ pose[k].inverse())

                # for i in range(6):
                #     pic_orig = images[0][i].cpu().clone()
                #     pic_orig = (pic_orig.squeeze(0).permute(1,2,0).detach().numpy()*255).astype(np.uint8)
                #     pic_ref = context_spatial[0][i].cpu().clone()
                #     pic_ref = (pic_ref.squeeze(0).permute(1, 2, 0).detach().numpy() * 255).astype(np.uint8)
                #     pic_warped = ref_warped[0][i].cpu().clone()
                #     pic_warped = (pic_warped.squeeze(0).permute(1,2,0).detach().numpy()*255).astype(np.uint8)
                #     pic_valid = valid_points_masks[0][i].cpu().clone()
                #     pic_valid = (pic_valid.permute(1,2,0).detach().numpy()*255).astype(np.uint8)
                #     final_frame = cv2.hconcat((pic_orig, pic_ref, pic_warped))
                #     cv2.imshow('temporal spatial warping', final_frame)
                #     cv2.waitKey()
                #     cv2.imshow('pic_valid', pic_valid)
                #     cv2.waitKey()

                # Calculate and store image loss
                photometric_loss = [self.calc_photometric_loss(ref_warped, images)[i] * valid_points_masks[i] \
                                    for i in range(len(valid_points_masks))]
                for i in range(self.n):
                    photometric_losses[i].append(
                        self.temporal_spatial_loss_weight *
                        photometric_loss[i])

        # Step 4: Calculate reduced photometric loss
        loss = self.reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
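
The left_swap / right_swap index lists built above pair each camera i with cameras i-1 and i+1 in the rig ordering (indices wrap around), which the surrounding code treats as its left and right spatial neighbours. A short standalone check:

num_cameras = 6
left_swap = [i for i in range(-1, num_cameras - 1)]                # [-1, 0, 1, 2, 3, 4]
right_swap = [i % num_cameras for i in range(1, num_cameras + 1)]  # [1, 2, 3, 4, 5, 0]
# Indexing a [6,3,H,W] tensor with one of these lists rolls the camera
# dimension, so image[left_swap][i] is the neighbour of image[i] (-1 wraps to 5).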
    def forward(self,
                image,
                ref_images_temporal_context,
                ref_images_geometric_context,
                ref_images_geometric_context_temporal_context,
                inv_depths,
                poses_temporal_context,
                poses_geometric_context,
                poses_geometric_context_temporal_context,
                camera_type,
                intrinsics_poly_coeffs,
                intrinsics_principal_point,
                intrinsics_scale_factors,
                intrinsics_K,
                intrinsics_k,
                intrinsics_p,
                path_to_ego_mask,
                camera_type_geometric_context,
                intrinsics_poly_coeffs_geometric_context,
                intrinsics_principal_point_geometric_context,
                intrinsics_scale_factors_geometric_context,
                intrinsics_K_geometric_context,
                intrinsics_k_geometric_context,
                intrinsics_p_geometric_context,
                path_to_ego_mask_geometric_context,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        ref_images_temporal_context : list of torch.Tensor [B,3,H,W]
            Reference images from the temporal context
        ref_images_geometric_context : list of torch.Tensor [B,3,H,W]
            Reference images from the geometric context (other cameras of the rig)
        ref_images_geometric_context_temporal_context : list of torch.Tensor [B,3,H,W]
            Reference images from the temporal context of the geometric context cameras
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the original image, in all scales
        poses_temporal_context, poses_geometric_context, poses_geometric_context_temporal_context
            Camera transformations between the original image and each context
        camera_type, intrinsics_poly_coeffs, intrinsics_principal_point, intrinsics_scale_factors, intrinsics_K, intrinsics_k, intrinsics_p, path_to_ego_mask
            Camera type, calibration parameters and ego-mask paths of the original camera
        camera_type_geometric_context, intrinsics_..._geometric_context, path_to_ego_mask_geometric_context
            Camera types, calibration parameters and ego-mask paths of the geometric context cameras (lists, one entry per context camera)
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_temporal_context = len(ref_images_temporal_context)
        n_geometric_context = len(ref_images_geometric_context)
        assert len(ref_images_geometric_context_temporal_context
                   ) == n_temporal_context * n_geometric_context
        B = len(path_to_ego_mask)

        device = image.get_device()

        # getting ego masks for target and source cameras
        # fullsize mask
        H_full = 800
        W_full = 1280
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor_geometric_context = []
        for i_geometric_context in range(n_geometric_context):
            ref_ego_mask_tensor_geometric_context.append(
                torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            ego_mask_tensor[b, 0] = torch.from_numpy(
                np.load(path_to_ego_mask[b])).float()
            for i_geometric_context in range(n_geometric_context):
                if camera_type_geometric_context[b, i_geometric_context] != 2:
                    ref_ego_mask_tensor_geometric_context[i_geometric_context][b, 0] = \
                        torch.from_numpy(np.load(path_to_ego_mask_geometric_context[i_geometric_context][b])).float()

        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors_geometric_context = []
        for i_geometric_context in range(n_geometric_context):
            ref_ego_mask_tensors_geometric_context.append([])
        for i in range(self.n):
            _, _, H, W = images[i].shape
            if W < W_full:
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor,
                                      shape=(B, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
                for i_geometric_context in range(n_geometric_context):
                    ref_ego_mask_tensors_geometric_context[
                        i_geometric_context].append(
                            interpolate_image(
                                ref_ego_mask_tensor_geometric_context[
                                    i_geometric_context],
                                shape=(B, 1, H, W),
                                mode='nearest',
                                align_corners=None))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_geometric_context in range(n_geometric_context):
                    ref_ego_mask_tensors_geometric_context[
                        i_geometric_context].append(
                            ref_ego_mask_tensor_geometric_context[
                                i_geometric_context])

        # Dummy camera mask (B x n_geometric_context)
        Cmask = (camera_type_geometric_context == 2).detach()

        # temporal context
        for j, (ref_image, pose) in enumerate(
                zip(ref_images_temporal_context, poses_temporal_context)):
            # Calculate warped images
            ref_warped, ref_ego_mask_tensors_warped = \
                self.warp_ref_image(inv_depths,
                                    camera_type,
                                    intrinsics_poly_coeffs,
                                    intrinsics_principal_point,
                                    intrinsics_scale_factors,
                                    intrinsics_K,
                                    intrinsics_k,
                                    intrinsics_p,
                                    ref_image,
                                    pose,
                                    ego_mask_tensors,
                                    camera_type,
                                    intrinsics_poly_coeffs,
                                    intrinsics_principal_point,
                                    intrinsics_scale_factors,
                                    intrinsics_K,
                                    intrinsics_k,
                                    intrinsics_p)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors_warped[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ego_mask_tensors[i])

        # geometric context
        for j, (ref_image, pose) in enumerate(
                zip(ref_images_geometric_context, poses_geometric_context)):
            if Cmask[:, j].sum() < B:
                # Calculate warped images
                ref_warped, ref_ego_mask_tensors_warped = \
                    self.warp_ref_image(inv_depths,
                                        camera_type,
                                        intrinsics_poly_coeffs,
                                        intrinsics_principal_point,
                                        intrinsics_scale_factors,
                                        intrinsics_K,
                                        intrinsics_k,
                                        intrinsics_p,
                                        ref_image,
                                        Pose(pose),
                                        ref_ego_mask_tensors_geometric_context[j],
                                        camera_type_geometric_context[:, j],
                                        intrinsics_poly_coeffs_geometric_context[j],
                                        intrinsics_principal_point_geometric_context[j],
                                        intrinsics_scale_factors_geometric_context[j],
                                        intrinsics_K_geometric_context[j],
                                        intrinsics_k_geometric_context[j],
                                        intrinsics_p_geometric_context[j])
                # Calculate and store image loss
                photometric_loss = self.calc_photometric_loss(
                    ref_warped, images)
                for i in range(self.n):
                    photometric_loss[i][Cmask[:, j]] = 0.0
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])
                # If using automask
                if self.automask_loss:
                    # Calculate and store unwarped image loss
                    ref_images = match_scales(ref_image, inv_depths, self.n)
                    unwarped_image_loss = self.calc_photometric_loss(
                        ref_images, images)
                    for i in range(self.n):
                        unwarped_image_loss[i][Cmask[:, j]] = 0.0
                        photometric_losses[i].append(
                            unwarped_image_loss[i] * ego_mask_tensors[i] *
                            ref_ego_mask_tensors_geometric_context[j][i])

        # geometric-temporal context
        for j, (ref_image, pose) in enumerate(
                zip(ref_images_geometric_context_temporal_context,
                    poses_geometric_context_temporal_context)):
            j_geometric = j // n_temporal_context
            if Cmask[:, j_geometric].sum() < B:
                # Calculate warped images
                ref_warped, ref_ego_mask_tensors_warped = \
                    self.warp_ref_image(inv_depths,
                                        camera_type,
                                        intrinsics_poly_coeffs,
                                        intrinsics_principal_point,
                                        intrinsics_scale_factors,
                                        intrinsics_K,
                                        intrinsics_k,
                                        intrinsics_p,
                                        ref_image,
                                        Pose(pose.mat @ poses_geometric_context[j_geometric]),  # WARNING: to be verified (change of reference frame!)
                                        ref_ego_mask_tensors_geometric_context[j_geometric],
                                        camera_type_geometric_context[:, j_geometric],
                                        intrinsics_poly_coeffs_geometric_context[j_geometric],
                                        intrinsics_principal_point_geometric_context[j_geometric],
                                        intrinsics_scale_factors_geometric_context[j_geometric],
                                        intrinsics_K_geometric_context[j_geometric],
                                        intrinsics_k_geometric_context[j_geometric],
                                        intrinsics_p_geometric_context[j_geometric])
                # Calculate and store image loss
                photometric_loss = self.calc_photometric_loss(
                    ref_warped, images)
                for i in range(self.n):
                    photometric_loss[i][Cmask[:, j_geometric]] = 0.0
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])
                # If using automask
                if self.automask_loss:
                    # Calculate and store unwarped image loss
                    ref_images = match_scales(ref_image, inv_depths, self.n)
                    unwarped_image_loss = self.calc_photometric_loss(
                        ref_images, images)
                    for i in range(self.n):
                        unwarped_image_loss[i][Cmask[:, j_geometric]] = 0.0
                        photometric_losses[i].append(
                            unwarped_image_loss[i] * ego_mask_tensors[i] *
                            ref_ego_mask_tensors_geometric_context[j_geometric]
                            [i])

        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
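
nonzero_reduce_photometric_loss is not shown in this snippet. Below is a hedged sketch of one way such a reduction could work: take the per-pixel minimum while ignoring zero entries (masked pixels and dummy cameras contribute zeros above), then average over pixels that received at least one valid term. The actual method may differ.

import torch

def nonzero_reduce_photometric_loss_sketch(photometric_losses):
    scale_losses = []
    for losses in photometric_losses:
        stacked = torch.cat(losses, dim=1)              # [B,N,H,W]
        nonzero = stacked > 0
        # Ignore zero entries when taking the per-pixel minimum, then average
        # only over pixels that have at least one nonzero term.
        stacked = stacked.masked_fill(~nonzero, float('inf'))
        min_loss, _ = torch.min(stacked, dim=1)         # [B,H,W]
        valid = nonzero.any(dim=1)
        scale_losses.append(min_loss[valid].mean())
    return sum(scale_losses) / len(scale_losses)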
    def warp_ref_image(self, inv_depths, camera_type, intrinsics_poly_coeffs,
                       intrinsics_principal_point, intrinsics_scale_factors,
                       intrinsics_K, intrinsics_k, intrinsics_p, ref_image,
                       ref_pose, ref_ego_mask_tensors, ref_camera_type,
                       ref_intrinsics_poly_coeffs,
                       ref_intrinsics_principal_point,
                       ref_intrinsics_scale_factors, ref_intrinsics_K,
                       ref_intrinsics_k, ref_intrinsics_p):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        camera_type, intrinsics_poly_coeffs, intrinsics_principal_point, intrinsics_scale_factors, intrinsics_K, intrinsics_k, intrinsics_p
            Camera type and calibration parameters of the original camera
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        ref_pose : Pose
            Original -> Reference camera transformation
        ref_ego_mask_tensors : list of torch.Tensor [B,1,H,W]
            Ego masks of the reference camera, in all scales
        ref_camera_type, ref_intrinsics_poly_coeffs, ref_intrinsics_principal_point, ref_intrinsics_scale_factors, ref_intrinsics_K, ref_intrinsics_k, ref_intrinsics_p
            Camera type and calibration parameters of the reference camera

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at each scale (reconstructing the original one)
        ref_tensors_warped : list of torch.Tensor [B,1,H,W]
            Warped reference ego masks at each scale
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(
                CameraMultifocal(intrinsics_poly_coeffs,
                                 intrinsics_principal_point,
                                 intrinsics_scale_factors,
                                 intrinsics_K,
                                 intrinsics_k[:, 0],
                                 intrinsics_k[:, 1],
                                 intrinsics_k[:, 2],
                                 intrinsics_p[:, 0],
                                 intrinsics_p[:, 1],
                                 camera_type,
                                 Tcw=None).scaled(scale_factor).to(device))
            ref_cams.append(
                CameraMultifocal(ref_intrinsics_poly_coeffs,
                                 ref_intrinsics_principal_point,
                                 ref_intrinsics_scale_factors,
                                 ref_intrinsics_K,
                                 ref_intrinsics_k[:, 0],
                                 ref_intrinsics_k[:, 1],
                                 ref_intrinsics_k[:, 2],
                                 ref_intrinsics_p[:, 0],
                                 ref_intrinsics_p[:, 1],
                                 ref_camera_type,
                                 Tcw=ref_pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [
            view_synthesis(ref_images[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode)
            for i in range(self.n)
        ]

        ref_tensors_warped = [
            view_synthesis(ref_ego_mask_tensors[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode,
                           mode='nearest') for i in range(self.n)
        ]
        # Return warped reference image
        return ref_warped, ref_tensors_warped
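
In warp_ref_image above the ego masks are warped with mode='nearest' while the images use the default (presumably bilinear) sampling. A tiny demonstration of why: bilinear interpolation blurs a binary mask into fractional values, nearest keeps it in {0, 1}.

import torch
import torch.nn.functional as F

mask = torch.tensor([[[[1., 0.], [0., 1.]]]])        # [1,1,2,2] binary mask
grid = torch.tensor([[[[-0.6, -0.6], [0.4, -0.6]],
                      [[-0.6, 0.4], [0.4, 0.4]]]])   # [1,2,2,2] sampling grid
bilinear = F.grid_sample(mask, grid, mode='bilinear', align_corners=True)
nearest = F.grid_sample(mask, grid, mode='nearest', align_corners=True)
# 'bilinear' produces fractional values, 'nearest' keeps the warped mask binary,
# which is why the ego masks are resampled with mode='nearest'.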
Example #12
    def forward(self, image, context, inv_depths,
                path_to_theta_lut,     path_to_ego_mask,     poly_coeffs,     principal_point,     scale_factors,
                ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors,  # ref_* arguments are lists (one entry per context image)
                poses, return_logs=False, progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye (Woodscape) calibration and ego-mask paths of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
            Fisheye calibration and ego-mask paths of the reference cameras (lists, one entry per context image)
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_context = len(context)

        if self.mask_ego:
            device = image.get_device()

            B = len(path_to_ego_mask)

            ego_mask_tensor     = torch.zeros(B, 1, 800, 1280).to(device)
            ref_ego_mask_tensor = []#[torch.zeros(B, 1, 800, 1280).to(device)] * n_context
            for i_context in range(n_context):
                ref_ego_mask_tensor.append(torch.zeros(B, 1, 800, 1280).to(device))
            for b in range(B):
                ego_mask_tensor[b, 0]     = torch.from_numpy(np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(np.load(ref_path_to_ego_mask[i_context][b])).float()

            ego_mask_tensors     = []  # = torch.zeros(B, 1, 800, 1280)
            ref_ego_mask_tensors = []#[[]] *  n_context  # = torch.zeros(B, 1, 800, 1280)
            for i_context in range(n_context):
                ref_ego_mask_tensors.append([])
            for i in range(self.n):
                B, C, H, W = images[i].shape
                if W < 1280:
                    inv_scale_factor = int(1280 / W)
                    ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                else:
                    ego_mask_tensors.append(ego_mask_tensor)
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(ref_ego_mask_tensor[i_context])
                # ego_mask_tensor = ego_mask_tensor.to(device)

            for i_context in range(n_context):
                B, C, H, W = context[i_context].shape
                if W < 1280:
                    inv_scale_factor = int(1280 / W)
                    ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            if self.mask_ego:
                ref_warped = self.warp_ref_image(inv_depths, ref_image * ref_ego_mask_tensor[j],
                                                 path_to_theta_lut,        path_to_ego_mask,        poly_coeffs,        principal_point,        scale_factors,
                                                 ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j], ref_principal_point[j], ref_scale_factors[j],
                                                 pose)
                photometric_loss = self.calc_photometric_loss([a * b for a, b in zip(ref_warped, ego_mask_tensors)],
                                                              [a * b for a, b in zip(images,     ego_mask_tensors)])
            else:
                ref_warped = self.warp_ref_image(inv_depths, ref_image,
                                                 path_to_theta_lut,        path_to_ego_mask,        poly_coeffs,        principal_point,        scale_factors,
                                                 ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j], ref_principal_point[j], ref_scale_factors[j],
                                                 pose)
                photometric_loss = self.calc_photometric_loss(ref_warped, images)
            # Store the photometric loss at each scale
            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                if self.mask_ego:
                    unwarped_image_loss = self.calc_photometric_loss([a * b for a, b in zip(ref_images, ref_ego_mask_tensors[j])],
                                                                     [a * b for a, b in zip(images,     ego_mask_tensors)])
                else:
                    unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i])
        # Calculate reduced photometric loss
        loss = self.reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            if self.mask_ego:
                loss += self.calc_smoothness_loss([a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                                                  [a * b for a, b in zip(images,     ego_mask_tensors)])
            else:
                loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
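    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # The forward above shrinks binary ego masks with -MaxPool2d(-x), which is a min-pool:
    # a low-resolution pixel stays valid (1) only if every full-resolution pixel it covers
    # is valid. The standalone helper below shows the same trick on its own, assuming a
    # float mask with values in {0, 1} and an integer downscaling factor.
    @staticmethod
    def _sketch_minpool_mask(mask, factor):
        """Min-pool a binary [B,1,H,W] mask by an integer factor (illustrative only)."""
        import torch.nn as nn
        # Max-pooling the negated mask and negating back keeps a pixel at 1 only if
        # all pixels in its window were 1, so no masked-out pixel becomes valid.
        return -nn.MaxPool2d(factor, factor)(-mask)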
    def forward(self,
                image,
                context,
                inv_depths,
                K,
                k,
                p,
                path_to_ego_mask,
                ref_K,
                ref_k,
                ref_p,
                path_to_ego_mask_context,
                poses,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        k : torch.Tensor [B,3]
            Original camera radial distortion coefficients (k1, k2, k3)
        p : torch.Tensor [B,2]
            Original camera tangential distortion coefficients (p1, p2)
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        ref_k : torch.Tensor [B,3]
            Reference camera radial distortion coefficients
        ref_p : torch.Tensor [B,2]
            Reference camera tangential distortion coefficients
        path_to_ego_mask_context : list
            Paths to the ego-mask .npy files of the context cameras
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_context = len(context)

        if self.mask_ego:
            device = image.get_device()
            B = len(path_to_ego_mask)
            H_full, W_full = np.load(path_to_ego_mask[0]).shape

            ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
            ref_ego_mask_tensor = []
            for i_context in range(n_context):
                ref_ego_mask_tensor.append(
                    torch.ones(B, 1, H_full, W_full).to(device))

            for b in range(B):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(
                            path_to_ego_mask_context[i_context][b])).float()

            # resized masks
            ego_mask_tensors = []
            ref_ego_mask_tensors = []
            for i_context in range(n_context):
                ref_ego_mask_tensors.append([])
            for i in range(self.n):
                Btmp, C, H, W = images[i].shape
                if W < W_full:
                    # Nearest-neighbour downscaling keeps the masks binary
                    ego_mask_tensors.append(
                        interpolate_image(ego_mask_tensor,
                                          shape=(Btmp, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(
                            interpolate_image(ref_ego_mask_tensor[i_context],
                                              shape=(Btmp, 1, H, W),
                                              mode='nearest',
                                              align_corners=None))
                else:
                    ego_mask_tensors.append(ego_mask_tensor)
                    for i_context in range(n_context):
                        ref_ego_mask_tensors[i_context].append(
                            ref_ego_mask_tensor[i_context])

            # Downscale the full-size context masks to the context image resolution
            for i_context in range(n_context):
                _, C, H, W = context[i_context].shape
                if W < W_full:
                    ref_ego_mask_tensor[i_context] = interpolate_image(
                        ref_ego_mask_tensor[i_context],
                        shape=(B, 1, H, W),
                        mode='nearest',
                        align_corners=None)

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            # Calculate warped images
            if self.mask_ego:
                ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image_tensor(
                    inv_depths, ref_image, K, k, p, ref_K, ref_k, ref_p,
                    ref_ego_mask_tensor[j], pose)
            else:
                ref_warped = self.warp_ref_image(inv_depths, ref_image, K, k,
                                                 p, ref_K, ref_k, ref_p, pose)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)
            if self.mask_ego:
                for i in range(self.n):
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])
            else:
                for i in range(self.n):
                    photometric_losses[i].append(photometric_loss[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(inv_depths, images)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
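    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # The forward above resizes ego masks to every prediction scale with the repo's
    # interpolate_image helper in 'nearest' mode, which keeps the masks binary. A minimal
    # equivalent using plain torch.nn.functional is sketched below; the helper name and
    # its exact signature in the repo may differ.
    @staticmethod
    def _sketch_resize_mask_nearest(mask, height, width):
        """Resize a [B,1,H,W] binary mask to (height, width) without blurring it."""
        import torch.nn.functional as F
        # 'nearest' interpolation copies the closest source pixel, so the output stays in {0, 1}
        return F.interpolate(mask, size=(height, width), mode='nearest')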
    def forward(self,
                image,
                context,
                inv_depths,
                poses,
                path_to_ego_mask,
                path_to_ego_mask_context,
                K,
                ref_K,
                extrinsics,
                ref_extrinsics,
                context_type,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        poses : list of Pose
            Camera transformation between original and context
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        path_to_ego_mask_context : list
            Paths to the ego-mask .npy files of the context cameras
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,N,3,3]
            Reference camera intrinsics, one per context camera
        extrinsics : torch.Tensor
            Original camera extrinsics
        ref_extrinsics : torch.Tensor
            Reference camera extrinsics (used instead of the predicted pose for spatial contexts)
        context_type : list
            Type of each context frame ('left'/'right' spatial or temporal)
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        # Debugging override: the warping below uses a constant depth of 2.0 for every pixel
        # instead of the predicted depths.
        inv_depths2 = []
        for k in range(len(inv_depths)):
            depths2 = 2.0 * torch.ones_like(inv_depths[k])
            inv_depths2.append(depth2inv(depths2))

        if is_list(context_type[0][0]):
            n_context = len(context_type[0])
        else:
            n_context = len(context_type)

        device = image.get_device()
        B = len(path_to_ego_mask)
        if is_list(path_to_ego_mask[0]):
            H_full, W_full = np.load(path_to_ego_mask[0][0]).shape
        else:
            H_full, W_full = np.load(path_to_ego_mask[0]).shape

        # getting ego masks for target and source cameras
        # fullsize mask
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(
                torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            if self.mask_ego:
                if is_list(path_to_ego_mask[b]):
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b][0])).float()
                else:
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    if is_list(path_to_ego_mask_context[0][0]):
                        paths_context_ego = [
                            p[i_context][0] for p in path_to_ego_mask_context
                        ]
                    else:
                        paths_context_ego = path_to_ego_mask_context[i_context]
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(paths_context_ego[b])).float()
        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            Btmp, C, H, W = images[i].shape
            if W < W_full:
                # Nearest-neighbour downscaling keeps the masks binary
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor,
                                      shape=(Btmp, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        interpolate_image(ref_ego_mask_tensor[i_context],
                                          shape=(Btmp, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        ref_ego_mask_tensor[i_context])
        # Downscale the full-size context masks to the context image resolution
        for i_context in range(n_context):
            _, C, H, W = context[i_context].shape
            if W < W_full:
                ref_ego_mask_tensor[i_context] = interpolate_image(
                    ref_ego_mask_tensor[i_context],
                    shape=(B, 1, H, W),
                    mode='nearest',
                    align_corners=None)

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            ref_context_type = [c[j][0] for c in context_type] if is_list(
                context_type[0][0]) else context_type[j]
            # Calculate warped images
            ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image(
                inv_depths2, ref_image, ref_ego_mask_tensor[j], K,
                ref_K[:, j, :, :], pose, ref_extrinsics[j], ref_context_type)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)

            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors_warped[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ref_ego_mask_tensors[j][i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
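    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # The forward above loads one ego-mask .npy file per batch element and stacks them
    # into a [B,1,H,W] tensor of zeros and ones. The helper below shows that step in
    # isolation, assuming every file stores a 2-D array of the same shape.
    @staticmethod
    def _sketch_load_ego_masks(paths, device):
        """Stack per-sample ego-mask .npy files into a [B,1,H,W] float tensor."""
        import numpy as np
        import torch
        masks = [torch.from_numpy(np.load(p)).float() for p in paths]
        # unsqueeze(1) adds the channel dimension expected by the photometric loss
        return torch.stack(masks, dim=0).unsqueeze(1).to(device)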
    def forward(self,
                inv_depths,
                gt_inv_depth,
                path_to_ego_mask,
                return_logs=False,
                progress=0.0):
        """
        Calculates training supervised loss.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        gt_inv_depth : torch.Tensor [B,1,H,W]
            Ground-truth inverse depth map for the original image
        path_to_ego_mask : list of str
            Paths to the ego-mask .npy files of the original camera
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Match predicted scales for ground-truth
        gt_inv_depths = match_scales(gt_inv_depth,
                                     inv_depths,
                                     self.n,
                                     mode='nearest',
                                     align_corners=None)

        if self.mask_ego:
            device = gt_inv_depth.get_device()
            B = len(path_to_ego_mask)
            # Full-resolution ego masks (assumed 800x1280)
            H_full = 800
            W_full = 1280
            ego_mask_tensor = torch.zeros(B, 1, H_full, W_full).to(device)
            for b in range(B):
                ego_mask_tensor[b, 0] = torch.from_numpy(
                    np.load(path_to_ego_mask[b])).float()
            ego_mask_tensors = []
            for i in range(self.n):
                B, C, H, W = inv_depths[i].shape
                if W < W_full:
                    ego_mask_tensors.append(
                        interpolate_image(ego_mask_tensor,
                                          shape=(B, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
                else:
                    ego_mask_tensors.append(ego_mask_tensor)

        # Calculate and store supervised loss
        if self.mask_ego:
            loss = self.calculate_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(gt_inv_depths, ego_mask_tensors)])
        else:
            loss = self.calculate_loss(inv_depths, gt_inv_depths)
        self.add_metric('supervised_loss', loss)
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
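    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # The supervised forward above multiplies both the predicted and the ground-truth
    # inverse depths by the ego mask before calling self.calculate_loss, so masked-out
    # pixels contribute zero on both sides. The snippet below shows the idea with a plain
    # masked L1 criterion; the actual criterion used by calculate_loss may differ.
    @staticmethod
    def _sketch_masked_l1(pred_inv_depth, gt_inv_depth, mask):
        """Mean L1 error between inverse depths, restricted to mask == 1 pixels."""
        import torch
        diff = torch.abs(pred_inv_depth - gt_inv_depth) * mask
        # Normalize by the number of valid pixels (clamped to avoid division by zero)
        return diff.sum() / mask.sum().clamp(min=1.0)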
    def warp_ref_image_tensor(self, inv_depths, ref_image, ref_tensor,
                              path_to_theta_lut, path_to_ego_mask, poly_coeffs,
                              principal_point, scale_factors,
                              ref_path_to_theta_lut, ref_path_to_ego_mask,
                              ref_poly_coeffs, ref_principal_point,
                              ref_scale_factors, same_timestamp_as_origin,
                              pose_matrix_context, pose):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        ref_tensor : torch.Tensor [B,1,H,W]
            Reference tensor (e.g. ego mask), warped alongside the image with nearest sampling
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye camera model parameters and ego-mask path of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
            Fisheye camera model parameters and ego-mask path of the reference camera
        same_timestamp_as_origin : torch.Tensor [B]
            Whether each reference frame is a spatial context (same timestamp as the original)
        pose_matrix_context : torch.Tensor [B,4,4]
            Pose matrix to use for spatial-context frames
        pose : Pose
            Original -> Reference camera transformation

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at all scales (reconstructing the original one)
        ref_tensors_warped : list of torch.Tensor [B,1,H,W]
            Warped reference tensor at all scales
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        # For spatial-context frames (same timestamp), use the known context pose
        # instead of the predicted one
        for b in range(B):
            if same_timestamp_as_origin[b]:
                pose.mat[b, :, :] = pose_matrix_context[b, :, :]
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(
                CameraFisheye(path_to_theta_lut=path_to_theta_lut,
                              path_to_ego_mask=path_to_ego_mask,
                              poly_coeffs=poly_coeffs.float(),
                              principal_point=principal_point.float(),
                              scale_factors=scale_factors.float()).scaled(
                                  scale_factor).to(device))
            ref_cams.append(
                CameraFisheye(path_to_theta_lut=ref_path_to_theta_lut,
                              path_to_ego_mask=ref_path_to_ego_mask,
                              poly_coeffs=ref_poly_coeffs.float(),
                              principal_point=ref_principal_point.float(),
                              scale_factors=ref_scale_factors.float(),
                              Tcw=pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [
            view_synthesis(ref_images[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode,
                           mode='bilinear') for i in range(self.n)
        ]
        ref_tensors = match_scales(ref_tensor,
                                   inv_depths,
                                   self.n,
                                   mode='nearest',
                                   align_corners=None)
        ref_tensors_warped = [
            view_synthesis(ref_tensors[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode,
                           mode='nearest',
                           align_corners=None) for i in range(self.n)
        ]
        # Return warped reference image
        return ref_warped, ref_tensors_warped
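    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # warp_ref_image_tensor above warps the RGB image with bilinear sampling but the
    # reference tensor (the ego mask) with 'nearest' sampling, so the warped mask stays
    # binary instead of being blurred at its borders. The snippet below shows that
    # difference directly with torch.nn.functional.grid_sample, assuming a sampling grid
    # in the usual [-1, 1] normalized coordinates.
    @staticmethod
    def _sketch_warp_image_and_mask(image, mask, grid, padding_mode='zeros'):
        """Warp an image bilinearly and its mask with nearest sampling on the same grid."""
        import torch.nn.functional as F
        warped_image = F.grid_sample(image, grid, mode='bilinear',
                                     padding_mode=padding_mode, align_corners=True)
        warped_mask = F.grid_sample(mask, grid, mode='nearest',
                                    padding_mode=padding_mode, align_corners=True)
        return warped_image, warped_mask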
    def forward(
            self,
            image,
            context,
            inv_depths,
            ref_inv_depths,
            path_to_theta_lut,
            path_to_ego_mask,
            poly_coeffs,
            principal_point,
            scale_factors,
            ref_path_to_theta_lut,
            ref_path_to_ego_mask,
            ref_poly_coeffs,
            ref_principal_point,
            ref_scale_factors,  # ALL LISTS !!!
            same_timestep_as_origin,
            pose_matrix_context,
            poses,
            return_logs=False,
            progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        ref_inv_depths : list of (list of) torch.Tensor [B,1,H,W]
            Predicted depth maps for the context images, per context and scale
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye camera model parameters and ego-mask paths of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors : list
            Fisheye camera model parameters and ego-mask paths of the reference cameras (one entry per context)
        same_timestep_as_origin : list of torch.Tensor [B]
            Whether each context frame shares the original frame's timestamp (spatial context)
        pose_matrix_context : list of torch.Tensor [B,4,4]
            Pose matrices to use for spatial-context frames
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        n_context = len(context)

        device = image.get_device()
        B = len(path_to_ego_mask)

        # Full-resolution ego masks (assumed 800x1280) for the target and each context camera
        ego_mask_tensor = torch.zeros(B, 1, 800, 1280).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(torch.zeros(B, 1, 800, 1280).to(device))
        for b in range(B):
            ego_mask_tensor[b, 0] = torch.from_numpy(
                np.load(path_to_ego_mask[b])).float()
            for i_context in range(n_context):
                ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                    np.load(ref_path_to_ego_mask[i_context][b])).float()

        # Downscaled ego masks, one per prediction scale (min-pooling keeps masked-out pixels masked)
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            B, C, H, W = images[i].shape
            if W < 1280:
                inv_scale_factor = int(1280 / W)
                ego_mask_tensors.append(-nn.MaxPool2d(
                    inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)
                        (-ref_ego_mask_tensor[i_context]))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        ref_ego_mask_tensor[i_context])

        # Downscale the full-size context masks to the context image resolution
        for i_context in range(n_context):
            B, C, H, W = context[i_context].shape
            if W < 1280:
                inv_scale_factor = int(1280 / W)
                ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(
                    inv_scale_factor,
                    inv_scale_factor)(-ref_ego_mask_tensor[i_context])

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            ref_warped, ref_ego_mask_tensors_warped, inv_depths_wrt_ref_cam, ref_inv_depths_warped \
                = self.warp_ref_image_depth(inv_depths, ref_image, ref_ego_mask_tensor[j],
                                            ref_inv_depths[j],
                                            path_to_theta_lut,        path_to_ego_mask,        poly_coeffs,        principal_point,        scale_factors,
                                            ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j], ref_principal_point[j], ref_scale_factors[j],
                                            same_timestep_as_origin[j],
                                            pose_matrix_context[j],
                                            pose)
            # Margins for the occlusion / disocclusion tests below
            coeff_margin_occlusion = 1.5
            coeff_delta_occlusion = 1.5

            photometric_loss = self.calc_photometric_loss(ref_warped, images)

            if self.occ_disocc_handling == 'masks':
                without_occlusion_masks1a = [
                    (inv_depths_wrt_ref_cam[i] <=
                     coeff_margin_occlusion * ref_inv_depths_warped[i])
                    for i in range(self.n)
                ]
                without_occlusion_masks1b = [
                    (ref_inv_depths_warped[i] <=
                     coeff_margin_occlusion * inv_depths_wrt_ref_cam[i])
                    for i in range(self.n)
                ]

                without_occlusion_masks2a = [
                    (inv2depth(
                        inv_depths_wrt_ref_cam[i]) <= coeff_delta_occlusion +
                     inv2depth(ref_inv_depths_warped[i]))
                    for i in range(self.n)
                ]
                without_occlusion_masks2b = [
                    (inv2depth(
                        ref_inv_depths_warped[i]) <= coeff_delta_occlusion +
                     inv2depth(inv_depths_wrt_ref_cam[i]))
                    for i in range(self.n)
                ]

                without_occlusion_masks = [
                    without_occlusion_masks1a[i] + without_occlusion_masks2b[i]
                    for i in range(self.n)
                ]
                without_disocclusion_masks = [
                    without_occlusion_masks1b[i] + without_occlusion_masks2a[i]
                    for i in range(self.n)
                ]

                if self.mask_occlusion and self.mask_disocclusion:
                    valid_pixels_occ = [
                        (without_occlusion_masks[i] *
                         without_disocclusion_masks[i]).float()
                        for i in range(self.n)
                    ]
                elif self.mask_occlusion:
                    valid_pixels_occ = [
                        without_occlusion_masks[i].float()
                        for i in range(self.n)
                    ]
                elif self.mask_disocclusion:
                    valid_pixels_occ = [
                        without_disocclusion_masks[i].float()
                        for i in range(self.n)
                    ]
                else:
                    valid_pixels_occ = [
                        torch.ones_like(inv_depths[i]) for i in range(self.n)
                    ]

                B = len(path_to_ego_mask)
                for b in range(B):
                    if (same_timestep_as_origin[j][b]
                            and not self.mask_spatial_context) or (
                                not same_timestep_as_origin[j][b]
                                and not self.mask_temporal_context):
                        # Disable the occlusion masking at every scale for this sample
                        for i in range(self.n):
                            valid_pixels_occ[i][b, :, :, :] = 1.0

                for i in range(self.n):
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i] * valid_pixels_occ[i])

            elif self.occ_disocc_handling == 'consistency_tensor':
                consistency_tensors = [
                    self.depth_consistency_weight * inv_depths_wrt_ref_cam[i] *
                    torch.abs(
                        inv2depth(inv_depths_wrt_ref_cam[i]) -
                        inv2depth(ref_inv_depths_warped[i]))
                    for i in range(self.n)
                ]
                B = len(path_to_ego_mask)
                for b in range(B):
                    if (same_timestep_as_origin[j][b]
                            and not self.mask_spatial_context) or (
                                not same_timestep_as_origin[j][b]
                                and not self.mask_temporal_context):
                        # Disable the consistency penalty at every scale for this sample
                        for i in range(self.n):
                            consistency_tensors[i][b, :, :, :] = 0.0

                for i in range(self.n):
                    photometric_losses[i].append(
                        (photometric_loss[i] + consistency_tensors[i]) *
                        ego_mask_tensors[i] * ref_ego_mask_tensors_warped[i])

            else:
                for i in range(self.n):
                    photometric_losses[i].append(
                        photometric_loss[i] * ego_mask_tensors[i] *
                        ref_ego_mask_tensors_warped[i])

            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ref_ego_mask_tensors[j][i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
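    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # The 'masks' branch above flags (dis)occluded pixels by comparing two depth maps that
    # should agree wherever the scene point is visible in both views: the target depth
    # expressed in the reference camera and the warped reference depth. The sketch below
    # keeps only the multiplicative-margin test; the code above additionally accepts pixels
    # that pass an additive margin in metric depth.
    @staticmethod
    def _sketch_occlusion_mask(inv_depth_wrt_ref, ref_inv_depth_warped, margin=1.5):
        """Return a float mask that is 1 where the pixel is NOT considered occluded."""
        # Occlusion: the reference camera sees something closer than the target's surface,
        # i.e. the warped reference inverse depth is much larger than the target's.
        not_occluded = inv_depth_wrt_ref <= margin * ref_inv_depth_warped
        return not_occluded.float()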
    def warp_ref_image(self, inv_depths, ref_image, ref_tensor, K, ref_K, pose,
                       ref_extrinsics, ref_context_type):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        ref_tensor : torch.Tensor [B,1,H,W]
            Reference tensor (e.g. ego mask), warped alongside the image with nearest sampling
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,3,3]
            Reference camera intrinsics
        pose : Pose
            Original -> Reference camera transformation
        ref_extrinsics : torch.Tensor [B,4,4]
            Reference camera extrinsics, used instead of the predicted pose for spatial contexts
        ref_context_type : list of str
            Type of each reference frame ('left'/'right' spatial or temporal)

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at all scales (reconstructing the original one)
        ref_tensors_warped : list of torch.Tensor [B,1,H,W]
            Warped reference tensor at all scales
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        for b in range(B):
            if ref_context_type[b] == 'left' or ref_context_type[b] == 'right':
                pose.mat[b, :, :] = ref_extrinsics[b, :, :]
                #pose.mat[b,:3,3]=0
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(Camera(K=K.float()).scaled(scale_factor).to(device))
            ref_cams.append(
                Camera(K=ref_K.float(),
                       Tcw=pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)
        ref_warped = [
            view_synthesis(ref_images[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode)
            for i in range(self.n)
        ]
        # Return warped reference image

        ref_tensors = match_scales(ref_tensor,
                                   inv_depths,
                                   self.n,
                                   mode='nearest',
                                   align_corners=None)
        ref_tensors_warped = [
            view_synthesis(ref_tensors[i],
                           depths[i],
                           ref_cams[i],
                           cams[i],
                           padding_mode=self.padding_mode,
                           mode='nearest') for i in range(self.n)
        ]

        return ref_warped, ref_tensors_warped
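    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # warp_ref_image above replaces the predicted pose with the known camera extrinsics for
    # spatial contexts (ref_context_type 'left' or 'right'), since the relative transform
    # between rigidly mounted cameras does not need to be predicted. A minimal version of
    # that per-sample selection is sketched below, operating on raw [B,4,4] matrices.
    @staticmethod
    def _sketch_select_pose(pred_pose_mat, extrinsics_mat, context_type):
        """Pick the extrinsic transform for spatial contexts, the predicted one otherwise."""
        out = pred_pose_mat.clone()
        for b, ctype in enumerate(context_type):
            if ctype in ('left', 'right'):
                # Spatial context: the fixed rig calibration is more reliable than the network
                out[b] = extrinsics_mat[b]
        return out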
    def warp_ref(self, inv_depths, ref_image, ref_tensor, ref_inv_depths,
                 path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors,
                 ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point,
                 ref_scale_factors,
                 same_timestamp_as_origin,
                 pose_matrix_context,
                 pose,
                 warp_ref_depth,
                 allow_context_rotation):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        ref_tensor : torch.Tensor [B,1,H,W]
            Reference tensor (e.g. ego mask), warped alongside the image with nearest sampling
        ref_inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted inverse depth maps for the reference image, in all scales
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye camera model parameters and ego-mask path of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
            Fisheye camera model parameters and ego-mask path of the reference camera
        same_timestamp_as_origin : torch.Tensor [B]
            Whether each reference frame is a spatial context (same timestamp as the original)
        pose_matrix_context : torch.Tensor [B,4,4]
            Pose matrix to use for spatial-context frames
        pose : Pose
            Original -> Reference camera transformation
        warp_ref_depth : bool
            If True, also warp the reference depth maps (needed for occlusion handling)
        allow_context_rotation : bool
            If True, compose the predicted rotation with the context pose instead of replacing the pose

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference image at all scales (reconstructing the original one)
        ref_tensors_warped : list of torch.Tensor [B,1,H,W]
            Warped reference tensor at all scales
        inv_depths_wrt_ref_cam : list of torch.Tensor [B,1,H,W] or None
            Original depths expressed in the reference camera, as inverse depths (None if warp_ref_depth is False)
        ref_inv_depths_warped : list of torch.Tensor [B,1,H,W] or None
            Warped reference inverse depth maps (None if warp_ref_depth is False)
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        if allow_context_rotation:
            pose_matrix = torch.zeros(B, 4, 4)
            for b in range(B):
                if same_timestamp_as_origin[b]:
                    pose_matrix[b, :3, 3] = pose.mat[b, :3, :3] @ pose_matrix_context[b, :3, 3]
                    pose_matrix[b, 3, 3] = 1
                    pose_matrix[b, :3, :3] = pose.mat[b, :3, :3] @ pose_matrix_context[b, :3, :3]
                else:
                    pose_matrix[b, :, :] = pose.mat[b, :, :]
        else:
            for b in range(B):
                if same_timestamp_as_origin[b]:
                    pose.mat[b, :, :] = pose_matrix_context[b, :, :]
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(CameraFisheye(path_to_theta_lut=path_to_theta_lut,
                                      path_to_ego_mask=path_to_ego_mask,
                                      poly_coeffs=poly_coeffs.float(),
                                      principal_point=principal_point.float(),
                                      scale_factors=scale_factors.float()).scaled(scale_factor).to(device))
            ref_cams.append(CameraFisheye(path_to_theta_lut=ref_path_to_theta_lut,
                                          path_to_ego_mask=ref_path_to_ego_mask,
                                          poly_coeffs=ref_poly_coeffs.float(),
                                          principal_point=ref_principal_point.float(),
                                          scale_factors=ref_scale_factors.float(),
                                          Tcw=pose_matrix if allow_context_rotation else pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)

        if warp_ref_depth:
            ref_depths = [inv2depth(ref_inv_depths[i]) for i in range(self.n)]
            ref_warped = []
            depths_wrt_ref_cam = []
            ref_depths_warped = []
            for i in range(self.n):
                view_i, depth_wrt_ref_cam_i, ref_depth_warped_i \
                    = view_depth_synthesis2(ref_images[i], depths[i], ref_depths[i], ref_cams[i], cams[i], padding_mode=self.padding_mode)
                ref_warped.append(view_i)
                depths_wrt_ref_cam.append(depth_wrt_ref_cam_i)
                ref_depths_warped.append(ref_depth_warped_i)
            inv_depths_wrt_ref_cam = [depth2inv(depths_wrt_ref_cam[i]) for i in range(self.n)]
            ref_inv_depths_warped = [depth2inv(ref_depths_warped[i]) for i in range(self.n)]
        else:
            ref_warped = [view_synthesis(ref_images[i], depths[i], ref_cams[i], cams[i], padding_mode=self.padding_mode) for i in range(self.n)]
            inv_depths_wrt_ref_cam = None
            ref_inv_depths_warped = None

        ref_tensors = match_scales(ref_tensor, inv_depths, self.n, mode='nearest', align_corners=None)
        ref_tensors_warped = [view_synthesis(
            ref_tensors[i], depths[i], ref_cams[i], cams[i],
            padding_mode=self.padding_mode, mode='nearest', align_corners=None) for i in range(self.n)]

        # Return warped reference image
        return ref_warped, ref_tensors_warped, inv_depths_wrt_ref_cam, ref_inv_depths_warped
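    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # With allow_context_rotation, warp_ref above builds the pose for spatial contexts by
    # applying the predicted rotation to both the rotation and the translation of the known
    # context pose, discarding the predicted translation. The sketch below reproduces that
    # composition for a single pair of [4,4] matrices.
    @staticmethod
    def _sketch_compose_context_pose(pred_mat, context_mat):
        """Rotate a known 4x4 context pose by the predicted rotation (translation discarded)."""
        import torch
        out = torch.zeros_like(context_mat)
        R_pred = pred_mat[:3, :3]
        out[:3, :3] = R_pred @ context_mat[:3, :3]
        out[:3, 3] = R_pred @ context_mat[:3, 3]
        out[3, 3] = 1.0
        return out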
    def forward(self, gt_depth, depths,
                path_to_theta_lut,     path_to_ego_mask,     poly_coeffs,     principal_point,     scale_factors,
                ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors, # ALL LISTS !!!
                poses, return_logs=False, progress=0.0):
        """
        Calculates a reprojected-pixel distance loss between ground-truth and predicted depth.

        Parameters
        ----------
        gt_depth : torch.Tensor [B,1,H,W]
            Ground-truth depth map for the original image
        depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye camera model parameters and ego-mask paths of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors : list
            Fisheye camera model parameters and ego-mask paths of the reference cameras (one entry per context)
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        reprojected_losses = [[] for _ in range(self.n)]
        gt_depths = match_scales(gt_depth, depths, self.n)
        if self.mask_ego:
            device = gt_depth.get_device()
            B = len(path_to_ego_mask)
            # Full-resolution ego masks (assumed 800x1280)
            ego_mask_tensor = torch.zeros(B, 1, 800, 1280)
            for b in range(B):
                ego_mask_tensor[b, 0] = torch.from_numpy(np.load(path_to_ego_mask[b])).float()
            ego_mask_tensors = []
            for i in range(self.n):
                B, C, H, W = gt_depths[i].shape
                if W < 1280:
                    inv_scale_factor = int(1280 / W)
                    ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor).to(device))
                else:
                    ego_mask_tensors.append(ego_mask_tensor.to(device))

            gt_depths = [a * b for a, b in zip(gt_depths, ego_mask_tensors)]

        gt_depths_mask = [(gt_depths[i] > 0.).detach() for i in range(self.n)]

        for j, pose in enumerate(poses):
            # Calculate warped images
            target_pixels_gt_warped, target_pixels_warped \
                = self.warp_target_pixels(gt_depths, depths,
                                          path_to_theta_lut,        path_to_ego_mask,        poly_coeffs,        principal_point,        scale_factors,
                                          ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j], ref_principal_point[j], ref_scale_factors[j],
                                          pose)

            for i in range(self.n):
                X_gt = target_pixels_gt_warped[i][:, 0, :, :].unsqueeze(1)[gt_depths_mask[i]]
                Y_gt = target_pixels_gt_warped[i][:, 1, :, :].unsqueeze(1)[gt_depths_mask[i]]

                X = target_pixels_warped[i][:, 0, :, :].unsqueeze(1)[gt_depths_mask[i]]
                Y = target_pixels_warped[i][:, 1, :, :].unsqueeze(1)[gt_depths_mask[i]]

                if self.mask_out_of_bounds_reprojected:
                    inside_of_bounds_mask = torch.logical_not(((X_gt > 1) + (X_gt < -1) + (Y_gt > 1) + (Y_gt < -1) + (X > 1) + (X < -1) + (Y > 1) + (Y < -1))).detach()
                    X_gt = X_gt[inside_of_bounds_mask]
                    Y_gt = Y_gt[inside_of_bounds_mask]
                    X    = X[inside_of_bounds_mask]
                    Y    = Y[inside_of_bounds_mask]

                pixels_gt = torch.stack([X_gt, Y_gt]).view(2, -1).transpose(0, 1) # [N, 2]
                pixels    = torch.stack([   X,    Y]).view(2, -1).transpose(0, 1) # [N, 2]

                # Mean Euclidean distance between GT and predicted reprojections (eps for stability)
                reprojected_loss = torch.mean(torch.sqrt(torch.sum((pixels_gt - pixels) ** 2, axis=1) + 1e-8))
                reprojected_losses[i].append(reprojected_loss)

        loss = sum([sum([l.mean() for l in reprojected_losses[i]]) / len(reprojected_losses[i]) for i in range(self.n)]) / self.n
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
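    # --- Editor's note: illustrative sketch, not part of the original training code. ---
    # The forward above compares where ground-truth and predicted depths reproject each
    # pixel into the reference camera, and penalizes the mean Euclidean distance between
    # the two sets of normalized pixel coordinates (with a small epsilon inside the square
    # root for numerical stability). The distance itself is isolated below.
    @staticmethod
    def _sketch_mean_pixel_distance(pixels_gt, pixels_pred, eps=1e-8):
        """Mean Euclidean distance between two [N,2] sets of pixel coordinates."""
        import torch
        return torch.mean(torch.sqrt(torch.sum((pixels_gt - pixels_pred) ** 2, dim=1) + eps))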
    def forward(self, image, context, inv_depths, ref_inv_depths,
                path_to_theta_lut,     path_to_ego_mask,     poly_coeffs,     principal_point,     scale_factors,
                ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors, # ALL LISTS !!!
                same_timestep_as_origin,
                pose_matrix_context,
                poses, return_logs=False, progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        ref_inv_depths : list of (list of) torch.Tensor [B,1,H,W]
            Predicted depth maps for the context images, per context and scale
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye camera model parameters and ego-mask paths of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors : list
            Fisheye camera model parameters and ego-mask paths of the reference cameras (one entry per context)
        same_timestep_as_origin : list of torch.Tensor [B]
            Whether each context frame shares the original frame's timestamp (spatial context)
        pose_matrix_context : list of torch.Tensor [B,4,4]
            Pose matrices to use for spatial-context frames
        poses : list of Pose
            Camera transformation between original and context
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)


        n_context = len(context)
        device = image.get_device()
        B = len(path_to_ego_mask)
        H_full, W_full = np.load(path_to_ego_mask[0]).shape

        # getting ego masks for target and source cameras
        # fullsize mask
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            if self.mask_ego:
                ego_mask_tensor[b, 0] = torch.from_numpy(np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(np.load(ref_path_to_ego_mask[i_context][b])).float()
        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            B, C, H, W = images[i].shape
            if W < W_full:
                # Nearest-neighbour downscaling keeps the masks binary
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor, shape=(B, 1, H, W), mode='nearest', align_corners=None)
                )
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        interpolate_image(ref_ego_mask_tensor[i_context],
                                          shape=(B, 1, H, W),
                                          mode='nearest',
                                          align_corners=None)
                    )
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(ref_ego_mask_tensor[i_context])
        # Downscale the full-size context masks to the context image resolution
        for i_context in range(n_context):
            B, C, H, W = context[i_context].shape
            if W < W_full:
                ref_ego_mask_tensor[i_context] = interpolate_image(ref_ego_mask_tensor[i_context],
                                                                   shape=(B, 1, H, W),
                                                                   mode='nearest',
                                                                   align_corners=None)

        B = len(path_to_ego_mask)

        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            ref_warped, ref_ego_mask_tensors_warped, inv_depths_wrt_ref_cam, ref_inv_depths_warped \
                = self.warp_ref(inv_depths, ref_image, ref_ego_mask_tensor[j], ref_inv_depths[j],
                                path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors,
                                ref_path_to_theta_lut[j], ref_path_to_ego_mask[j], ref_poly_coeffs[j],
                                ref_principal_point[j], ref_scale_factors[j],
                                same_timestep_as_origin[j], pose_matrix_context[j],
                                pose,
                                warp_ref_depth=self.use_ref_depth,
                                allow_context_rotation=self.allow_context_rotation)

            photometric_loss = self.calc_photometric_loss(ref_warped, images)

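            # Occlusion / disocclusion masking based on the two warped inverse depth maps:
            # a pixel passes the "no occlusion" test when
            #   inv_depth_wrt_ref <= mult_margin_occlusion * ref_inv_depth_warped
            # or
            #   depth(ref_inv_depth_warped) <= depth(inv_depth_wrt_ref) + add_margin_occlusion,
            # and the "no disocclusion" test is the symmetric criterion with the two depths swapped.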
            if (self.mask_occlusion or self.mask_disocclusion) and (self.mask_spatial_context or self.mask_temporal_context):
                no_occlusion_masks    = [(inv_depths_wrt_ref_cam[i] <= self.mult_margin_occlusion * ref_inv_depths_warped[i])
                                         + (inv2depth(ref_inv_depths_warped[i]) <= self.add_margin_occlusion  + inv2depth(inv_depths_wrt_ref_cam[i]))
                                         for i in range(self.n)] # boolean OR
                no_disocclusion_masks = [(ref_inv_depths_warped[i] <= self.mult_margin_occlusion * inv_depths_wrt_ref_cam[i])
                                         + (inv2depth(inv_depths_wrt_ref_cam[i]) <= self.add_margin_occlusion  + inv2depth(ref_inv_depths_warped[i]))
                                         for i in range(self.n)] # boolean OR
                if self.mask_occlusion and self.mask_disocclusion:
                    valid_pixels_occ = [(no_occlusion_masks[i] * no_disocclusion_masks[i]).float() for i in range(self.n)]
                elif self.mask_occlusion:
                    valid_pixels_occ = [no_occlusion_masks[i].float() for i in range(self.n)]
                elif self.mask_disocclusion:
                    valid_pixels_occ = [no_disocclusion_masks[i].float() for i in range(self.n)]
                for b in range(B):
                    if (same_timestep_as_origin[j][b] and not self.mask_spatial_context) \
                            or (not same_timestep_as_origin[j][b] and not self.mask_temporal_context):
                        # Occlusion masking does not apply to this kind of context pair:
                        # mark every pixel as valid, at every scale (not only the last one)
                        for i in range(self.n):
                            valid_pixels_occ[i][b, :, :, :] = 1.0
            else:
                valid_pixels_occ = [torch.ones_like(inv_depths[i]) for i in range(self.n)]

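            # Optional depth consistency term: weight the absolute difference between the
            # two depth estimates (target depth expressed in the reference camera vs. the
            # warped reference depth) by each inverse depth, and keep the pixel-wise
            # minimum of the two weightings.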
            if self.depth_consistency_weight > 0.0:
                consistency_tensors_1 = [self.depth_consistency_weight * inv_depths_wrt_ref_cam[i]
                                         * torch.abs(inv2depth(inv_depths_wrt_ref_cam[i]) - inv2depth(ref_inv_depths_warped[i]))
                                         for i in range(self.n)]
                consistency_tensors_2 = [self.depth_consistency_weight * ref_inv_depths_warped[i]
                                         * torch.abs(inv2depth(inv_depths_wrt_ref_cam[i]) - inv2depth(ref_inv_depths_warped[i]))
                                         for i in range(self.n)]
                consistency_tensors = [torch.cat([consistency_tensors_1[i], consistency_tensors_2[i]], 1).min(1, True)[0]
                                       for i in range(self.n)]
            else:
                consistency_tensors = [torch.zeros_like(inv_depths[i]) for i in range(self.n)]

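            # Accumulate the per-scale loss (photometric + consistency), masked by the
            # target ego mask, the warped reference ego mask and the occlusion validity mask.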
            for i in range(self.n):
                photometric_losses[i].append((photometric_loss[i] + consistency_tensors[i])
                                             * ego_mask_tensors[i] * ref_ego_mask_tensors_warped[i] * valid_pixels_occ[i])

            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] * ego_mask_tensors[i] * ref_ego_mask_tensors[j][i])

        # Calculate reduced photometric loss
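        # (assumption from the method name: the "nonzero" reduction averages only over
        # pixels that were not zeroed out by the masks above)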
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss([a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                                              [a * b for a, b in zip(images,     ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }