def forward(self, batch, return_logs=False, progress=0.0, **kwargs):
        """
        Processes a batch.

        Parameters
        ----------
        batch : dict
            Input batch
        return_logs : bool
            True if logs are stored
        progress : float
            Training progress percentage

        Returns
        -------
        output : dict
            Dictionary containing a "loss" scalar and different metrics and predictions
            for logging and downstream usage.
        """
        if not self.training:
            # If not training, no need for self-supervised loss
            return SfmModel.forward(self, batch, return_logs=return_logs, **kwargs)
        else:
            if self.supervised_loss_weight == 1.:
                # If no self-supervision, no need to calculate loss
                self_sup_output = SfmModel.forward(self, batch, return_logs=return_logs, **kwargs)
                loss = torch.tensor([0.]).type_as(batch['rgb'])
            else:
                # Otherwise, calculate and weight self-supervised loss
                self_sup_output = SelfSupModel.forward(
                    self, batch, return_logs=return_logs, progress=progress, **kwargs)
                loss = (1.0 - self.supervised_loss_weight) * self_sup_output['loss']
            # Calculate and weight supervised loss
            sup_output = self.supervised_loss(
                self_sup_output['inv_depths'], depth2inv(batch['depth']),
                return_logs=return_logs, progress=progress)
            loss += self.supervised_loss_weight * sup_output['loss']
            if 'inv_depths_rgbd' in self_sup_output:
                sup_output2 = self.supervised_loss(
                    self_sup_output['inv_depths_rgbd'], depth2inv(batch['depth']),
                    return_logs=return_logs, progress=progress)
                loss += self.weight_rgbd * self.supervised_loss_weight * sup_output2['loss']
                if 'depth_loss' in self_sup_output:
                    loss += self_sup_output['depth_loss']
            # Merge and return outputs
            return {
                'loss': loss,
                **merge_outputs(self_sup_output, sup_output),
            }
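# --- Illustrative sketch (not part of the original code) ----------------------
# The forward pass above blends a self-supervised photometric loss with a
# supervised depth loss through `supervised_loss_weight`. A minimal sketch of
# that weighting rule, assuming both losses are scalar tensors; the helper
# name `blend_losses` is hypothetical.
import torch

def blend_losses(self_sup_loss, sup_loss, supervised_loss_weight):
    """Weighted sum used by the semi-supervised forward pass."""
    return (1.0 - supervised_loss_weight) * self_sup_loss \
        + supervised_loss_weight * sup_loss

# With a weight of 1.0, only the supervised term contributes.
loss = blend_losses(torch.tensor(0.8), torch.tensor(0.2), supervised_loss_weight=1.0)
assert torch.isclose(loss, torch.tensor(0.2))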
Example n. 2
    def forward(self, batch, return_logs=False, progress=0.0):
        """
        Processes a batch.

        Parameters
        ----------
        batch : dict
            Input batch
        return_logs : bool
            True if logs are stored
        progress : float
            Training progress percentage

        Returns
        -------
        output : dict
            Dictionary containing a "loss" scalar and different metrics and predictions
            for logging and downstream usage.
        """
        if not self.training:
            # If not training, no need for self-supervised loss
            return SfmModel.forward(self, batch)
        else:
            # Calculate predicted depth and pose output
            output = super().forward(batch, return_logs=return_logs)

            # Build ground-truth relative poses (context -> current) for each
            # element of the batch: T_rel = T_context^{-1} @ T_current
            n_poses = len(batch['pose'])
            poses_gt = [torch.zeros((n_poses, 4, 4)), torch.zeros((n_poses, 4, 4))]
            for i in range(n_poses):
                poses_gt[0][i] = batch['pose_context'][0][i].inverse() @ batch['pose'][i]
                poses_gt[1][i] = batch['pose_context'][1][i].inverse() @ batch['pose'][i]
            poses_gt = [Pose(poses_gt[0]), Pose(poses_gt[1])]

            multiview_loss = self.multiview_photometric_loss(
                batch['rgb_original'], batch['rgb_context_original'],
                output['inv_depths'], poses_gt, batch['intrinsics'], batch['extrinsics'],
                return_logs=return_logs, progress=progress)

            # The multiview photometric loss is computed for its metrics/logs
            # only; it is not added to the total loss here.
            # loss = multiview_loss['loss']
            loss = 0.
            # Calculate supervised loss
            supervision_loss = self.supervised_loss(output['inv_depths'], depth2inv(batch['depth']),
                                                    return_logs=return_logs, progress=progress)
            loss += self.supervised_loss_weight * supervision_loss['loss']

            # Return loss and metrics
            return {
                'loss': loss,
                **merge_outputs(merge_outputs(multiview_loss, supervision_loss), output)
            }
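# --- Illustrative sketch (not part of the original code) ----------------------
# The loop above builds ground-truth relative poses as
#     T_rel = T_context^{-1} @ T_current
# for every element of the batch. A vectorized sketch of the same operation on
# [B,4,4] pose matrices; the helper name `relative_poses` is hypothetical.
import torch

def relative_poses(pose_context, pose):
    """Relative transform from a context frame to the current frame, per batch element."""
    # torch.linalg.inv and @ both broadcast over the batch dimension.
    return torch.linalg.inv(pose_context) @ pose

B = 6
pose = torch.eye(4).repeat(B, 1, 1)
pose_context = torch.eye(4).repeat(B, 1, 1)
pose_context[:, :3, 3] = 1.0  # context frames shifted by (1, 1, 1)
rel = relative_poses(pose_context, pose)
assert rel.shape == (B, 4, 4)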
    def forward(self,
                image,
                context,
                inv_depths,
                poses,
                path_to_ego_mask,
                path_to_ego_mask_context,
                K,
                ref_K,
                extrinsics,
                ref_extrinsics,
                context_type,
                return_logs=False,
                progress=0.0):
        """
        Calculates training photometric loss.

        Parameters
        ----------
        image : torch.Tensor [B,3,H,W]
            Original image
        context : list of torch.Tensor [B,3,H,W]
            Context containing a list of reference images
        inv_depths : list of torch.Tensor [B,1,H,W]
            Predicted depth maps for the original image, in all scales
        poses : list of Pose
            Camera transformations between original and context images
        path_to_ego_mask : list
            Paths to the ego-mask files of the original camera (one per batch element)
        path_to_ego_mask_context : list
            Paths to the ego-mask files of the context cameras
        K : torch.Tensor [B,3,3]
            Original camera intrinsics
        ref_K : torch.Tensor [B,n_context,3,3]
            Reference camera intrinsics
        extrinsics : torch.Tensor
            Original camera extrinsics
        ref_extrinsics : torch.Tensor
            Reference camera extrinsics
        context_type : list
            Context type identifiers, forwarded to the warping function
        return_logs : bool
            True if logs are saved for visualization
        progress : float
            Training percentage

        Returns
        -------
        losses_and_metrics : dict
            Output dictionary
        """
        # If using progressive scaling
        self.n = self.progressive_scaling(progress)
        # Loop over all reference images
        photometric_losses = [[] for _ in range(self.n)]
        images = match_scales(image, inv_depths, self.n)

        # Debug override: the predicted depths are replaced with a constant
        # depth of 2.0 for batch element 0 (the commented-out expression is an
        # alternative analytic depth pattern).
        inv_depths2 = []
        for k in range(len(inv_depths)):
            depths2 = torch.zeros_like(inv_depths[k])
            # depths2[0, 0, i, j] = (2/H)**4 * (2/W)**4 * i**2 * (H-i)**2 * j**2 * (W-j)**2 * 20
            depths2[0, 0] = 2.0
            inv_depths2.append(depth2inv(depths2))

        if is_list(context_type[0][0]):
            n_context = len(context_type[0])
        else:
            n_context = len(context_type)

        #n_context = len(context)
        device = image.get_device()
        B = len(path_to_ego_mask)
        if is_list(path_to_ego_mask[0]):
            H_full, W_full = np.load(path_to_ego_mask[0][0]).shape
        else:
            H_full, W_full = np.load(path_to_ego_mask[0]).shape

        # getting ego masks for target and source cameras
        # fullsize mask
        ego_mask_tensor = torch.ones(B, 1, H_full, W_full).to(device)
        ref_ego_mask_tensor = []
        for i_context in range(n_context):
            ref_ego_mask_tensor.append(
                torch.ones(B, 1, H_full, W_full).to(device))
        for b in range(B):
            if self.mask_ego:
                if is_list(path_to_ego_mask[b]):
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b][0])).float()
                else:
                    ego_mask_tensor[b, 0] = torch.from_numpy(
                        np.load(path_to_ego_mask[b])).float()
                for i_context in range(n_context):
                    if is_list(path_to_ego_mask_context[0][0]):
                        paths_context_ego = [
                            p[i_context][0] for p in path_to_ego_mask_context
                        ]
                    else:
                        paths_context_ego = path_to_ego_mask_context[i_context]
                    ref_ego_mask_tensor[i_context][b, 0] = torch.from_numpy(
                        np.load(paths_context_ego[b])).float()
        # resized masks
        ego_mask_tensors = []
        ref_ego_mask_tensors = []
        for i_context in range(n_context):
            ref_ego_mask_tensors.append([])
        for i in range(self.n):
            Btmp, C, H, W = images[i].shape
            if W < W_full:
                #inv_scale_factor = int(W_full / W)
                #print(W_full / W)
                #ego_mask_tensors.append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ego_mask_tensor))
                ego_mask_tensors.append(
                    interpolate_image(ego_mask_tensor,
                                      shape=(Btmp, 1, H, W),
                                      mode='nearest',
                                      align_corners=None))
                for i_context in range(n_context):
                    #ref_ego_mask_tensors[i_context].append(-nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context]))
                    ref_ego_mask_tensors[i_context].append(
                        interpolate_image(ref_ego_mask_tensor[i_context],
                                          shape=(Btmp, 1, H, W),
                                          mode='nearest',
                                          align_corners=None))
            else:
                ego_mask_tensors.append(ego_mask_tensor)
                for i_context in range(n_context):
                    ref_ego_mask_tensors[i_context].append(
                        ref_ego_mask_tensor[i_context])
        for i_context in range(n_context):
            _, C, H, W = context[i_context].shape
            if W < W_full:
                inv_scale_factor = int(W_full / W)
                #ref_ego_mask_tensor[i_context] = -nn.MaxPool2d(inv_scale_factor, inv_scale_factor)(-ref_ego_mask_tensor[i_context])
                ref_ego_mask_tensor[i_context] = interpolate_image(
                    ref_ego_mask_tensor[i_context],
                    shape=(Btmp, 1, H, W),
                    mode='nearest',
                    align_corners=None)

        # Debug prints: reference extrinsics, context types and pose matrices
        print(ref_extrinsics)
        for j, (ref_image, pose) in enumerate(zip(context, poses)):
            print(ref_extrinsics[j])
            ref_context_type = [c[j][0] for c in context_type] if is_list(
                context_type[0][0]) else context_type[j]
            print(ref_context_type)
            print(pose.mat)
            # Calculate warped images
            ref_warped, ref_ego_mask_tensors_warped = self.warp_ref_image(
                inv_depths2, ref_image, ref_ego_mask_tensor[j], K,
                ref_K[:, j, :, :], pose, ref_extrinsics[j], ref_context_type)
            print(pose.mat)
            # Calculate and store image loss
            photometric_loss = self.calc_photometric_loss(ref_warped, images)

            # Debug dump: write the original / warped / target images (with and
            # without ego masks applied) to disk for visual inspection
            tt = str(int(time.time() % 10000))
            for i in range(self.n):
                B, C, H, W = images[i].shape
                for b in range(B):
                    orig_PIL_0 = torch.transpose(
                        (ref_image[b, :, :, :]).unsqueeze(0).unsqueeze(4), 1,
                        4).squeeze().detach().cpu().numpy()
                    orig_PIL = torch.transpose(
                        (ref_image[b, :, :, :] *
                         ref_ego_mask_tensor[j][b, :, :, :]
                         ).unsqueeze(0).unsqueeze(4), 1,
                        4).squeeze().detach().cpu().numpy()
                    warped_PIL_0 = torch.transpose(
                        (ref_warped[i][b, :, :, :]).unsqueeze(0).unsqueeze(4),
                        1, 4).squeeze().detach().cpu().numpy()
                    warped_PIL = torch.transpose(
                        (ref_warped[i][b, :, :, :] *
                         ego_mask_tensors[i][b, :, :, :]
                         ).unsqueeze(0).unsqueeze(4), 1,
                        4).squeeze().detach().cpu().numpy()
                    target_PIL_0 = torch.transpose(
                        (images[i][b, :, :, :]).unsqueeze(0).unsqueeze(4), 1,
                        4).squeeze().detach().cpu().numpy()
                    target_PIL = torch.transpose(
                        (images[i][b, :, :, :] *
                         ego_mask_tensors[i][b, :, :, :]
                         ).unsqueeze(0).unsqueeze(4), 1,
                        4).squeeze().detach().cpu().numpy()

                    cv2.imwrite(
                        '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                        tt + '_' + str(b) + '_' + str(i) + '_orig_PIL_0.png',
                        orig_PIL_0 * 255)
                    cv2.imwrite(
                        '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                        tt + '_' + str(b) + '_' + str(i) + '_orig_PIL.png',
                        orig_PIL * 255)
                    cv2.imwrite(
                        '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                        tt + '_' + str(b) + '_' + str(i) + '_warped_PIL_0.png',
                        warped_PIL_0 * 255)
                    cv2.imwrite(
                        '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                        tt + '_' + str(b) + '_' + str(i) + '_warped_PIL.png',
                        warped_PIL * 255)
                    cv2.imwrite(
                        '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                        tt + '_' + str(b) + '_' + str(i) + '_target_PIL_0.png',
                        target_PIL_0 * 255)
                    cv2.imwrite(
                        '/home/users/vbelissen/test' + '_' + str(j) + '_' +
                        tt + '_' + str(b) + '_' + str(i) + '_target_PIL.png',
                        target_PIL * 255)

            for i in range(self.n):
                photometric_losses[i].append(photometric_loss[i] *
                                             ego_mask_tensors[i] *
                                             ref_ego_mask_tensors_warped[i])
            # If using automask
            if self.automask_loss:
                # Calculate and store unwarped image loss
                ref_images = match_scales(ref_image, inv_depths, self.n)
                unwarped_image_loss = self.calc_photometric_loss(
                    ref_images, images)
                for i in range(self.n):
                    photometric_losses[i].append(unwarped_image_loss[i] *
                                                 ego_mask_tensors[i] *
                                                 ref_ego_mask_tensors[j][i])
        # Calculate reduced photometric loss
        loss = self.nonzero_reduce_photometric_loss(photometric_losses)
        # Include smoothness loss if requested
        if self.smooth_loss_weight > 0.0:
            loss += self.calc_smoothness_loss(
                [a * b for a, b in zip(inv_depths, ego_mask_tensors)],
                [a * b for a, b in zip(images, ego_mask_tensors)])
        # Return losses and metrics
        return {
            'loss': loss.unsqueeze(0),
            'metrics': self.metrics,
        }
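# --- Illustrative sketch (not part of the original code) ----------------------
# `nonzero_reduce_photometric_loss` is specific to this repository. A common
# reduction over the per-source photometric maps gathered above is the
# per-pixel minimum (as in automasking), averaged over pixels and scales.
# Minimal sketch assuming each scale holds a list of [B,1,H,W] tensors; the
# helper name `min_reduce_photometric_loss` is hypothetical and the actual
# "nonzero" reduction may differ (e.g. by ignoring masked-out pixels).
import torch

def min_reduce_photometric_loss(photometric_losses):
    """Per-pixel minimum over sources, then mean, averaged over scales."""
    per_scale = []
    for scale_losses in photometric_losses:
        stacked = torch.cat(scale_losses, dim=1)              # [B, n_sources, H, W]
        per_scale.append(stacked.min(dim=1, keepdim=True)[0].mean())
    return sum(per_scale) / len(per_scale)

losses = [[torch.rand(2, 1, 32, 32) for _ in range(3)] for _ in range(4)]
print(min_reduce_photometric_loss(losses))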
Example n. 4
                                       mode='bilinear',
                                       padding_mode='zeros',
                                       align_corners=True)

warped_right_front_PIL = torch.transpose(warped_right_front.unsqueeze(4), 1,
                                         4).squeeze().cpu().numpy()
cv2.imwrite('/home/users/vbelissen/test' + tt + '_right_front.png',
            warped_right_front_PIL[:, :, ::-1])

simulated_depth_right_to_front = funct.grid_sample(simulated_depth_right,
                                                   ref_coords_right,
                                                   mode='bilinear',
                                                   padding_mode='zeros',
                                                   align_corners=True)

viz_pred_inv_depth_front = viz_inv_depth(depth2inv(simulated_depth)[0],
                                         normalizer=1.0) * 255
viz_pred_inv_depth_right = viz_inv_depth(depth2inv(simulated_depth_right)[0],
                                         normalizer=1.0) * 255
viz_pred_inv_depth_right_to_front = viz_inv_depth(
    depth2inv(simulated_depth_right_to_front)[0], normalizer=1.0) * 255

world_points_right_in_front_coords = cam_front.Tcw @ world_points_right
simulated_depth_right_in_front_coords = torch.norm(
    world_points_right_in_front_coords, dim=1, keepdim=True)
simulated_depth_right_in_front_coords[~not_masked_right] = 0
simulated_depth_right_to_front_in_front_coords = funct.grid_sample(
    simulated_depth_right_in_front_coords,
    ref_coords_right,
    mode='bilinear',
    padding_mode='zeros',
    align_corners=True)
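# --- Illustrative sketch (not part of the original script) --------------------
# The calls above resample the right camera's image and depth into the front
# camera's frame with `grid_sample`, using precomputed pixel coordinates
# normalized to [-1, 1]. A self-contained sketch of that resampling step; the
# identity grid used here is only for illustration.
import torch
import torch.nn.functional as funct

B, C, H, W = 1, 3, 8, 16
source = torch.rand(B, C, H, W)

# Identity sampling grid of shape [B, H, W, 2], with (x, y) ordering.
ys, xs = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W), indexing='ij')
ref_coords = torch.stack((xs, ys), dim=-1).unsqueeze(0)

warped = funct.grid_sample(source, ref_coords, mode='bilinear',
                           padding_mode='zeros', align_corners=True)
assert torch.allclose(warped, source, atol=1e-5)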
    def warp_ref(self, inv_depths, ref_image, ref_tensor, ref_inv_depths,
                 path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors,
                 ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point,
                 ref_scale_factors,
                 same_timestamp_as_origin,
                 pose_matrix_context,
                 pose,
                 warp_ref_depth,
                 allow_context_rotation):
        """
        Warps a reference image to produce a reconstruction of the original one.

        Parameters
        ----------
        inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps of the original image, in all scales
        ref_image : torch.Tensor [B,3,H,W]
            Reference RGB image
        ref_tensor : torch.Tensor [B,1,H,W]
            Reference tensor (e.g. ego mask) warped alongside the image
        ref_inv_depths : list of torch.Tensor [B,1,H,W]
            Inverse depth maps predicted for the reference image
        path_to_theta_lut, path_to_ego_mask, poly_coeffs, principal_point, scale_factors
            Fisheye parameters of the original camera
        ref_path_to_theta_lut, ref_path_to_ego_mask, ref_poly_coeffs, ref_principal_point, ref_scale_factors
            Fisheye parameters of the reference camera
        same_timestamp_as_origin : list/tensor of bool [B]
            Whether each reference frame was captured at the same timestamp as the original
        pose_matrix_context : torch.Tensor [B,4,4]
            Extrinsic pose matrix of the context camera
        pose : Pose
            Original -> Reference camera transformation
        warp_ref_depth : bool
            Whether to also warp the reference depth maps
        allow_context_rotation : bool
            Whether to compose the predicted rotation with the context pose for same-timestamp frames

        Returns
        -------
        ref_warped : list of torch.Tensor [B,3,H,W]
            Warped reference images (reconstructing the original one), in all scales
        ref_tensors_warped : list of torch.Tensor [B,1,H,W]
            Warped reference tensors, in all scales
        inv_depths_wrt_ref_cam : list of torch.Tensor or None
            Inverse depths of the original image expressed in the reference camera (None if warp_ref_depth is False)
        ref_inv_depths_warped : list of torch.Tensor or None
            Warped reference inverse depth maps (None if warp_ref_depth is False)
        """
        B, _, H, W = ref_image.shape
        device = ref_image.get_device()
        # Generate cameras for all scales
        cams, ref_cams = [], []
        if allow_context_rotation:
            pose_matrix = torch.zeros(B, 4, 4)
            for b in range(B):
                if same_timestamp_as_origin[b]:
                    pose_matrix[b, :3, 3] = pose.mat[b, :3, :3] @ pose_matrix_context[b, :3, 3]
                    pose_matrix[b, 3, 3] = 1
                    pose_matrix[b, :3, :3] = pose.mat[b, :3, :3] @ pose_matrix_context[b, :3, :3]
                else:
                    pose_matrix[b, :, :] = pose.mat[b, :, :]
        else:
            for b in range(B):
                if same_timestamp_as_origin[b]:
                    pose.mat[b, :, :] = pose_matrix_context[b, :, :]
        for i in range(self.n):
            _, _, DH, DW = inv_depths[i].shape
            scale_factor = DW / float(W)
            cams.append(CameraFisheye(path_to_theta_lut=path_to_theta_lut,
                                      path_to_ego_mask=path_to_ego_mask,
                                      poly_coeffs=poly_coeffs.float(),
                                      principal_point=principal_point.float(),
                                      scale_factors=scale_factors.float()).scaled(scale_factor).to(device))
            ref_cams.append(CameraFisheye(path_to_theta_lut=ref_path_to_theta_lut,
                                          path_to_ego_mask=ref_path_to_ego_mask,
                                          poly_coeffs=ref_poly_coeffs.float(),
                                          principal_point=ref_principal_point.float(),
                                          scale_factors=ref_scale_factors.float(),
                                          Tcw=pose_matrix if allow_context_rotation else pose).scaled(scale_factor).to(device))
        # View synthesis
        depths = [inv2depth(inv_depths[i]) for i in range(self.n)]
        ref_images = match_scales(ref_image, inv_depths, self.n)

        if warp_ref_depth:
            ref_depths = [inv2depth(ref_inv_depths[i]) for i in range(self.n)]
            ref_warped = []
            depths_wrt_ref_cam = []
            ref_depths_warped = []
            for i in range(self.n):
                view_i, depth_wrt_ref_cam_i, ref_depth_warped_i \
                    = view_depth_synthesis2(ref_images[i], depths[i], ref_depths[i], ref_cams[i], cams[i], padding_mode=self.padding_mode)
                ref_warped.append(view_i)
                depths_wrt_ref_cam.append(depth_wrt_ref_cam_i)
                ref_depths_warped.append(ref_depth_warped_i)
            inv_depths_wrt_ref_cam = [depth2inv(depths_wrt_ref_cam[i]) for i in range(self.n)]
            ref_inv_depths_warped = [depth2inv(ref_depths_warped[i]) for i in range(self.n)]
        else:
            ref_warped = [view_synthesis(ref_images[i], depths[i], ref_cams[i], cams[i], padding_mode=self.padding_mode) for i in range(self.n)]
            inv_depths_wrt_ref_cam = None
            ref_inv_depths_warped = None

        ref_tensors = match_scales(ref_tensor, inv_depths, self.n, mode='nearest', align_corners=None)
        ref_tensors_warped = [view_synthesis(
            ref_tensors[i], depths[i], ref_cams[i], cams[i],
            padding_mode=self.padding_mode, mode='nearest', align_corners=None) for i in range(self.n)]

        # Return warped reference image
        return ref_warped, ref_tensors_warped, inv_depths_wrt_ref_cam, ref_inv_depths_warped
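# --- Illustrative sketch (not part of the original code) ----------------------
# `inv2depth` / `depth2inv` are used throughout these snippets to switch
# between depth and inverse-depth parameterizations. A minimal sketch of what
# such helpers typically do (invalid / zero pixels are kept at zero); the
# actual repository implementations may also handle lists of scales and use
# different clamping.
import torch

def depth2inv_sketch(depth, eps=1e-6):
    inv_depth = 1.0 / depth.clamp(min=eps)
    inv_depth[depth <= 0] = 0.0   # keep invalid pixels at zero
    return inv_depth

def inv2depth_sketch(inv_depth, eps=1e-6):
    depth = 1.0 / inv_depth.clamp(min=eps)
    depth[inv_depth <= 0] = 0.0
    return depth

d = torch.tensor([[0.0, 2.0, 10.0]])
assert torch.allclose(inv2depth_sketch(depth2inv_sketch(d)), d)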
Example n. 6
def infer_optimal_calib(input_files, model_wrappers, image_shape):
    """
    Process a list of input files to infer correction in extrinsic calibration.
    Files should all correspond to the same car.
    Number of cameras is assumed to be 4 or 5.

    Parameters
    ----------
    input_files : list (number of cameras) of lists (number of files) of str
        Image files
    model_wrappers : list of nn.Module
        Model wrappers used for inference (one per camera)
    image_shape : tuple of int
        Input image shape
    """
    N_files = len(input_files[0])
    N_cams = len(input_files)
    image_area = image_shape[0] * image_shape[1]

    camera_context_pairs = CAMERA_CONTEXT_PAIRS[N_cams]

    # Rotation will be optimized if not all cams are frozen
    optimize_rotation = (args.frozen_cams_rot != [i for i in range(N_cams)])

    # Translation will be optimized if not all cams are frozen
    optimize_translation = (args.frozen_cams_trans !=
                            [i for i in range(N_cams)])

    calib_data = {}
    for i_cam in range(N_cams):
        base_folder_str = get_base_folder(input_files[i_cam][0])
        split_type_str = get_split_type(input_files[i_cam][0])
        seq_name_str = get_sequence_name(input_files[i_cam][0])
        camera_str = get_camera_name(input_files[i_cam][0])
        calib_data[camera_str] = read_raw_calib_files_camera_valeo_with_suffix(
            base_folder_str, split_type_str, seq_name_str, camera_str,
            args.calibrations_suffix)

    cams = []
    cams_untouched = []
    not_masked = []

    # Assume all images are from the same sequence (thus same cameras)
    for i_cam in range(N_cams):
        path_to_ego_mask = get_path_to_ego_mask(input_files[i_cam][0])
        poly_coeffs, principal_point, scale_factors, K, k, p = get_full_intrinsics(
            input_files[i_cam][0], calib_data)

        poly_coeffs_untouched = torch.from_numpy(poly_coeffs).unsqueeze(0)
        principal_point_untouched = torch.from_numpy(
            principal_point).unsqueeze(0)
        scale_factors_untouched = torch.from_numpy(scale_factors).unsqueeze(0)
        K_untouched = torch.from_numpy(K).unsqueeze(0)
        k_untouched = torch.from_numpy(k).unsqueeze(0)
        p_untouched = torch.from_numpy(p).unsqueeze(0)
        pose_matrix_untouched = torch.from_numpy(
            get_extrinsics_pose_matrix(input_files[i_cam][0],
                                       calib_data)).unsqueeze(0)
        pose_tensor_untouched = Pose(pose_matrix_untouched)
        camera_type_untouched = get_camera_type(input_files[i_cam][0],
                                                calib_data)
        camera_type_int_untouched = torch.tensor(
            [get_camera_type_int(camera_type_untouched)])

        cams.append(
            CameraMultifocal(poly_coeffs=poly_coeffs_untouched.float(),
                             principal_point=principal_point_untouched.float(),
                             scale_factors=scale_factors_untouched.float(),
                             K=K_untouched.float(),
                             k1=k_untouched[:, 0].float(),
                             k2=k_untouched[:, 1].float(),
                             k3=k_untouched[:, 2].float(),
                             p1=p_untouched[:, 0].float(),
                             p2=p_untouched[:, 1].float(),
                             camera_type=camera_type_int_untouched,
                             Tcw=pose_tensor_untouched))

        cams_untouched.append(
            CameraMultifocal(poly_coeffs=poly_coeffs_untouched.float(),
                             principal_point=principal_point_untouched.float(),
                             scale_factors=scale_factors_untouched.float(),
                             K=K_untouched.float(),
                             k1=k_untouched[:, 0].float(),
                             k2=k_untouched[:, 1].float(),
                             k3=k_untouched[:, 2].float(),
                             p1=p_untouched[:, 0].float(),
                             p2=p_untouched[:, 1].float(),
                             camera_type=camera_type_int_untouched,
                             Tcw=pose_tensor_untouched))
        if torch.cuda.is_available():
            cams[i_cam] = cams[i_cam].to('cuda:{}'.format(rank()))
            cams_untouched[i_cam] = cams_untouched[i_cam].to('cuda:{}'.format(
                rank()))

        ego_mask = np.load(path_to_ego_mask)
        not_masked.append(
            torch.from_numpy(ego_mask.astype(float)).cuda().float())

    # Learning variables
    extra_trans_m = [
        torch.zeros(3, device='cuda', requires_grad=True)
        for _ in range(N_cams)
    ]
    extra_rot_deg = [
        torch.zeros(3, device='cuda', requires_grad=True)
        for _ in range(N_cams)
    ]

    # Constraints: translation
    frozen_cams_trans = args.frozen_cams_trans
    if frozen_cams_trans is not None:
        for i_cam in frozen_cams_trans:
            extra_trans_m[i_cam].requires_grad = False
    # Constraints: rotation
    frozen_cams_rot = args.frozen_cams_rot
    if frozen_cams_rot is not None:
        for i_cam in frozen_cams_rot:
            extra_rot_deg[i_cam].requires_grad = False

    # Parameters from argument parser
    save_pictures = args.save_pictures
    n_epochs = args.n_epochs
    learning_rate = args.lr
    step_size = args.scheduler_step_size
    gamma = args.scheduler_gamma

    # Table of loss
    loss_tab = np.zeros(n_epochs)

    # Table of extra rotation values
    extra_rot_values_tab = np.zeros((N_cams * 3, N_files * n_epochs))

    # Table of extra translation values
    extra_trans_values_tab = np.zeros((N_cams * 3, N_files * n_epochs))

    # Optimizer
    optimizer = optim.Adam(extra_trans_m + extra_rot_deg, lr=learning_rate)

    # Scheduler
    scheduler = optim.lr_scheduler.StepLR(optimizer,
                                          step_size=step_size,
                                          gamma=gamma)

    # Regularization weights
    regul_weight_trans = torch.tensor(args.regul_weight_trans).cuda()
    regul_weight_rot = torch.tensor(args.regul_weight_rot).cuda()
    regul_weight_overlap = torch.tensor(args.regul_weight_overlap).cuda()

    # Loop on the number of epochs
    count = 0
    for epoch in range(n_epochs):
        print('Epoch ' + str(epoch) + '/' + str(n_epochs))

        # Initialize loss
        loss_sum = 0

        # Loop on the number of files
        for i_file in range(N_files):

            print('')
            # Filename for camera 0
            base_0, ext_0 = os.path.splitext(
                os.path.basename(input_files[0][i_file]))
            print(base_0)

            # Initialize list of tensors: images, predicted inverse depths and predicted depths
            images, pred_inv_depths, pred_depths = [], [], []
            input_depth_files, has_gt_depth, gt_depth, gt_inv_depth = [], [], [], []
            nb_gt_depths = 0

            # Reset camera poses Twc
            CameraMultifocal.Twc.fget.cache_clear()

            # Loop on cams and predict depth
            for i_cam in range(N_cams):
                images.append(
                    load_image(input_files[i_cam][i_file]).convert('RGB'))
                images[i_cam] = resize_image(images[i_cam], image_shape)
                images[i_cam] = to_tensor(images[i_cam]).unsqueeze(0)
                if torch.cuda.is_available():
                    images[i_cam] = images[i_cam].to('cuda:{}'.format(rank()))
                with torch.no_grad():
                    pred_inv_depths.append(model_wrappers[i_cam].depth(
                        images[i_cam]))
                    pred_depths.append(inv2depth(pred_inv_depths[i_cam]))

                if args.use_lidar:
                    input_depth_files.append(
                        get_depth_file(input_files[i_cam][i_file],
                                       args.depth_suffix))
                    has_gt_depth.append(
                        os.path.exists(input_depth_files[i_cam]))
                    if has_gt_depth[i_cam]:
                        nb_gt_depths += 1
                        gt_depth.append(
                            np.load(input_depth_files[i_cam])
                            ['velodyne_depth'].astype(np.float32))
                        gt_depth[i_cam] = torch.from_numpy(
                            gt_depth[i_cam]).unsqueeze(0).unsqueeze(0)
                        gt_inv_depth.append(depth2inv(gt_depth[i_cam]))
                        if torch.cuda.is_available():
                            gt_depth[i_cam] = gt_depth[i_cam].to(
                                'cuda:{}'.format(rank()))
                            gt_inv_depth[i_cam] = gt_inv_depth[i_cam].to(
                                'cuda:{}'.format(rank()))
                    else:
                        gt_depth.append(0)
                        gt_inv_depth.append(0)

                # Apply correction on cams
                pose_matrix = get_extrinsics_pose_matrix_extra_trans_rot_torch(
                    input_files[i_cam][i_file], calib_data,
                    extra_trans_m[i_cam], extra_rot_deg[i_cam]).unsqueeze(0)
                pose_tensor = Pose(pose_matrix).to('cuda:{}'.format(rank()))
                cams[i_cam].Tcw = pose_tensor

            # Define a loss function between 2 images
            def photo_loss_2imgs(i_cam1, i_cam2, save_pictures):
                # Computes the photometric loss between 2 images of adjacent cameras
                # It reconstructs each image from the adjacent one, applying correction in rotation and translation

                # Reconstruct 3D points for each cam
                world_points1 = cams[i_cam1].reconstruct(pred_depths[i_cam1],
                                                         frame='w')
                world_points2 = cams[i_cam2].reconstruct(pred_depths[i_cam2],
                                                         frame='w')

                # Get coordinates of projected points on other cam
                ref_coords1to2 = cams[i_cam2].project(world_points1, frame='w')
                ref_coords2to1 = cams[i_cam1].project(world_points2, frame='w')

                # Reconstruct each image from the adjacent camera
                reconstructedImg2to1 = funct.grid_sample(images[i_cam2] *
                                                         not_masked[i_cam2],
                                                         ref_coords1to2,
                                                         mode='bilinear',
                                                         padding_mode='zeros',
                                                         align_corners=True)
                reconstructedImg1to2 = funct.grid_sample(images[i_cam1] *
                                                         not_masked[i_cam1],
                                                         ref_coords2to1,
                                                         mode='bilinear',
                                                         padding_mode='zeros',
                                                         align_corners=True)
                # Save pictures if requested
                if save_pictures:
                    # Save original files if first epoch
                    if epoch == 0:
                        cv2.imwrite(
                            args.save_folder + '/cam_' + str(i_cam1) +
                            '_file_' + str(i_file) + '_orig.png',
                            (images[i_cam1][0].permute(1, 2, 0)
                             )[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)
                        cv2.imwrite(
                            args.save_folder + '/cam_' + str(i_cam2) +
                            '_file_' + str(i_file) + '_orig.png',
                            (images[i_cam2][0].permute(1, 2, 0)
                             )[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)
                    # Save reconstructed images
                    cv2.imwrite(
                        args.save_folder + '/epoch_' + str(epoch) + '_file_' +
                        str(i_file) + '_cam_' + str(i_cam1) + '_recons_from_' +
                        str(i_cam2) + '.png',
                        ((reconstructedImg2to1 *
                          not_masked[i_cam1])[0].permute(1, 2, 0)
                         )[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)
                    cv2.imwrite(
                        args.save_folder + '/epoch_' + str(epoch) + '_file_' +
                        str(i_file) + '_cam_' + str(i_cam2) + '_recons_from_' +
                        str(i_cam1) + '.png',
                        ((reconstructedImg1to2 *
                          not_masked[i_cam2])[0].permute(1, 2, 0)
                         )[:, :, [2, 1, 0]].detach().cpu().numpy() * 255)

                # L1 loss
                l1_loss_1 = torch.abs(images[i_cam1] * not_masked[i_cam1] -
                                      reconstructedImg2to1 *
                                      not_masked[i_cam1])
                l1_loss_2 = torch.abs(images[i_cam2] * not_masked[i_cam2] -
                                      reconstructedImg1to2 *
                                      not_masked[i_cam2])

                # SSIM loss
                ssim_loss_weight = 0.85
                ssim_loss_1 = SSIM(images[i_cam1] * not_masked[i_cam1],
                                   reconstructedImg2to1 * not_masked[i_cam1],
                                   C1=1e-4,
                                   C2=9e-4,
                                   kernel_size=3)
                ssim_loss_2 = SSIM(images[i_cam2] * not_masked[i_cam2],
                                   reconstructedImg1to2 * not_masked[i_cam2],
                                   C1=1e-4,
                                   C2=9e-4,
                                   kernel_size=3)

                ssim_loss_1 = torch.clamp((1. - ssim_loss_1) / 2., 0., 1.)
                ssim_loss_2 = torch.clamp((1. - ssim_loss_2) / 2., 0., 1.)

                # Photometric loss: alpha * ssim + (1 - alpha) * l1
                photometric_loss_1 = ssim_loss_weight * ssim_loss_1.mean(
                    1, True) + (1 - ssim_loss_weight) * l1_loss_1.mean(
                        1, True)
                photometric_loss_2 = ssim_loss_weight * ssim_loss_2.mean(
                    1, True) + (1 - ssim_loss_weight) * l1_loss_2.mean(
                        1, True)

                # Compute the number of valid pixels
                mask1 = (reconstructedImg2to1 * not_masked[i_cam1]).sum(
                    axis=1, keepdim=True) != 0
                s1 = mask1.sum().float()
                mask2 = (reconstructedImg1to2 * not_masked[i_cam2]).sum(
                    axis=1, keepdim=True) != 0
                s2 = mask2.sum().float()

                # Compute the photometric losses weighed by the number of valid pixels
                loss_1 = (photometric_loss_1 *
                          mask1).sum() / s1 if s1 > 0 else 0
                loss_2 = (photometric_loss_2 *
                          mask2).sum() / s2 if s2 > 0 else 0

                # The final loss can be regularized to encourage a similar overlap between images
                if s1 > 0 and s2 > 0:
                    return loss_1 + loss_2 + regul_weight_overlap * image_area * (
                        1 / s1 + 1 / s2)
                else:
                    return 0.

            def lidar_loss(i_cam1, save_pictures):
                if args.use_lidar and has_gt_depth[i_cam1]:
                    mask_zeros_lidar = (
                        gt_depth[i_cam1][0, 0, :, :] == 0).detach()

                    # Ground truth sparse depth maps were generated using the untouched camera extrinsics
                    world_points_gt_oldCalib = cams_untouched[
                        i_cam1].reconstruct(gt_depth[i_cam1], frame='w')
                    world_points_gt_oldCalib[0, 0, mask_zeros_lidar] = 0.

                    # Get coordinates of projected points on new cam
                    ref_coords = cams[i_cam1].project(world_points_gt_oldCalib,
                                                      frame='w')
                    ref_coords[0, mask_zeros_lidar, :] = 0.

                    # Reconstruct projected lidar from the new camera
                    reprojected_gt_inv_depth = funct.grid_sample(
                        gt_inv_depth[i_cam1],
                        ref_coords,
                        mode='nearest',
                        padding_mode='zeros',
                        align_corners=True)
                    reprojected_gt_inv_depth[0, 0, mask_zeros_lidar] = 0.

                    mask_reprojected = (reprojected_gt_inv_depth > 0.).detach()
                    if save_pictures:
                        mask_reprojected_numpy = mask_reprojected[
                            0, 0, :, :].cpu().numpy()
                        u = np.where(mask_reprojected_numpy)[0]
                        v = np.where(mask_reprojected_numpy)[1]
                        n_lidar = u.size
                        reprojected_gt_depth_numpy = inv2depth(
                            reprojected_gt_inv_depth)[
                                0, 0, :, :].detach().cpu().numpy()

                        im = (images[i_cam1][0].permute(
                            1, 2, 0))[:, :,
                                      [2, 1, 0]].detach().cpu().numpy() * 255
                        dmax = 100.
                        for i_l in range(n_lidar):
                            d = reprojected_gt_depth_numpy[u[i_l], v[i_l]]
                            s = int((8 / d)) + 1
                            im[u[i_l] - s:u[i_l] + s, v[i_l] - s:v[i_l] + s,
                               0] = np.clip(
                                   np.power(d / dmax, .7) * 255, 10, 245)
                            im[u[i_l] - s:u[i_l] + s, v[i_l] - s:v[i_l] + s,
                               1] = np.clip(
                                   np.power((dmax - d) / dmax, 4.0) * 255, 10,
                                   245)
                            im[u[i_l] - s:u[i_l] + s, v[i_l] - s:v[i_l] + s,
                               2] = np.clip(
                                   np.power(np.abs(2 * (d - .5 * dmax) / dmax),
                                            3.0) * 255, 10, 245)

                        cv2.imwrite(
                            args.save_folder + '/epoch_' + str(epoch) +
                            '_file_' + str(i_file) + '_cam_' + str(i_cam1) +
                            '_lidar.png', im)

                    if mask_reprojected.sum() > 0:
                        return l1_lidar_loss(
                            pred_inv_depths[i_cam1] * not_masked[i_cam1],
                            reprojected_gt_inv_depth * not_masked[i_cam1])
                    else:
                        return 0.
                else:
                    return 0.

            if nb_gt_depths > 0:
                final_lidar_weight = (N_cams /
                                      nb_gt_depths) * args.lidar_weight
            else:
                final_lidar_weight = 0.

            # The final loss consists of summing the photometric loss of all pairs of adjacent cameras
            # and is regularized to prevent weights from exploding
            photo_loss = 1.0 * sum([
                photo_loss_2imgs(p[0], p[1], save_pictures)
                for p in camera_context_pairs
            ])
            regul_rot_loss = regul_weight_rot * sum(
                [(extra_rot_deg[i]**2).sum() for i in range(N_cams)])
            regul_trans_loss = regul_weight_trans * sum(
                [(extra_trans_m[i]**2).sum() for i in range(N_cams)])
            lidar_gt_loss = final_lidar_weight * sum(
                [lidar_loss(i, save_pictures) for i in range(N_cams)])

            loss = photo_loss + regul_rot_loss + regul_trans_loss + lidar_gt_loss

            with torch.no_grad():
                extra_rot_deg_before = []
                extra_trans_m_before = []
                for i_cam in range(N_cams):
                    extra_rot_deg_before.append(extra_rot_deg[i_cam].clone())
                    extra_trans_m_before.append(extra_trans_m[i_cam].clone())

            # Optimization steps
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                extra_rot_deg_after = []
                extra_trans_m_after = []
                for i_cam in range(N_cams):
                    extra_rot_deg_after.append(extra_rot_deg[i_cam].clone())
                    extra_trans_m_after.append(extra_trans_m[i_cam].clone())

                rot_change_file = 0.
                trans_change_file = 0.

                for i_cam in range(N_cams):
                    rot_change_file += torch.abs(
                        extra_rot_deg_after[i_cam] -
                        extra_rot_deg_before[i_cam]).mean().item()
                    trans_change_file += torch.abs(
                        extra_trans_m_after[i_cam] -
                        extra_trans_m_before[i_cam]).mean().item()

                rot_change_file /= N_cams
                trans_change_file /= N_cams

                print('Average rotation change (deg.): ' +
                      "{:.4f}".format(rot_change_file))
                print('Average translation change (m.): ' +
                      "{:.4f}".format(trans_change_file))

            # Save correction values and print loss
            with torch.no_grad():
                loss_sum += loss.item()
                for i_cam in range(N_cams):
                    for j in range(3):
                        if optimize_rotation:
                            extra_rot_values_tab[
                                3 * i_cam + j,
                                count] = extra_rot_deg[i_cam][j].item()
                        if optimize_translation:
                            extra_trans_values_tab[
                                3 * i_cam + j,
                                count] = extra_trans_m[i_cam][j].item()
                print('Loss: ' + "{:.3f}".format(loss.item()) \
                      + ' (photometric: ' + "{:.3f}".format(photo_loss.item()) \
                      + ', rotation reg.: ' + "{:.4f}".format(regul_rot_loss.item()) \
                      + ', translation reg.: ' + "{:.4f}".format(regul_trans_loss.item())
                      + ', lidar: ' + "{:.3f}".format(lidar_gt_loss) +')')
                if nb_gt_depths > 0:
                    print('Number of ground truth lidar maps: ' +
                          str(nb_gt_depths))

            count += 1

        # Update scheduler
        print('Epoch:', epoch, 'LR:', scheduler.get_last_lr())
        scheduler.step()
        with torch.no_grad():
            print('End of epoch')
            if optimize_translation:
                print('New translation correction values: ')
                print(extra_trans_m)
            if optimize_rotation:
                print('New rotation correction values: ')
                print(extra_rot_deg)
            print('Average rotation change in epoch:')

        loss_tab[epoch] = loss_sum / N_files

    # Plot/save loss if requested
    plt.figure()
    plt.plot(loss_tab)
    if args.show_plots:
        plt.show()
    if args.save_plots:
        plt.savefig(
            os.path.join(args.save_folder,
                         get_sequence_name(input_files[0][0]) + '_loss.png'))

    # Plot/save correction values if requested
    if optimize_rotation:
        plt.figure()
        for j in range(N_cams * 3):
            plt.plot(extra_rot_values_tab[j])
        if args.show_plots:
            plt.show()
        if args.save_plots:
            plt.savefig(
                os.path.join(
                    args.save_folder,
                    get_sequence_name(input_files[0][0]) + '_extra_rot.png'))

    if optimize_translation:
        plt.figure()
        for j in range(N_cams * 3):
            plt.plot(extra_trans_values_tab[j])
        if args.show_plots:
            plt.show()
        if args.save_plots:
            plt.savefig(
                os.path.join(
                    args.save_folder,
                    get_sequence_name(input_files[0][0]) + '_extra_trans.png'))

    # Save correction values table if requested
    if args.save_rot_tab:
        np.save(
            os.path.join(args.save_folder,
                         get_sequence_name(input_files[0][0]) +
                         '_rot_tab.npy'), extra_rot_values_tab)
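# --- Illustrative sketch (not part of the original script) --------------------
# Condensed view of the optimization setup used in `infer_optimal_calib`:
# one 3-vector of translation and one of rotation correction per camera,
# optimized with Adam under a StepLR schedule. The quadratic objective below
# is a stand-in for the photometric / lidar losses and is purely illustrative.
import torch
from torch import optim

N_cams, n_epochs = 4, 5
extra_trans_m = [torch.zeros(3, requires_grad=True) for _ in range(N_cams)]
extra_rot_deg = [torch.zeros(3, requires_grad=True) for _ in range(N_cams)]

optimizer = optim.Adam(extra_trans_m + extra_rot_deg, lr=1e-3)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.5)

target = torch.tensor([0.1, -0.2, 0.05])
for epoch in range(n_epochs):
    optimizer.zero_grad()
    # Stand-in objective: pull translations toward a target, rotations toward zero.
    loss = sum(((t - target) ** 2).sum() for t in extra_trans_m) \
        + sum((r ** 2).sum() for r in extra_rot_deg)
    loss.backward()
    optimizer.step()
    scheduler.step()
    print('epoch {}: loss {:.4f}, lr {:.5f}'.format(epoch, loss.item(), scheduler.get_last_lr()[0]))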