def show_epipolar_rui(x1, x2, img1_rgb, img2_rgb, F_gt, im_shape):
    """Show matched points on two images and the epipolar lines of ``x1`` in image 2.

    Left subplot: ``img1_rgb`` with the points ``x1``.  Right subplot:
    ``img2_rgb`` with the points ``x2`` and, for every ``x1[i]``, the line
    ``F_gt @ homo(x1[i])`` drawn across the full image width.  Points of the
    same index share a (random) scatter color.

    Args:
        x1, x2: matched point coordinates; column 0 is x, column 1 is y
            (presumably [N, 2] pixel coordinates -- confirm with callers).
        img1_rgb, img2_rgb: images to display; non-3-D arrays use a gray colormap.
        F_gt: fundamental matrix applied to the homogenized ``x1``.
        im_shape: image shape; ``im_shape[0]`` is the height, ``im_shape[1]`` the width.
    """
    N_points = x1.shape[0]
    x1_homo = utils_misc.homo_np(x1)
    # Column i holds the coefficients (a, b, c) of the line ax + by + c = 0.
    right_P = np.matmul(F_gt, x1_homo.T)
    # Evaluate every line at x = 0 and x = W (one column per point).
    right_epipolar_x = np.tile(np.array([[0], [1]]), N_points) * im_shape[1]
    # Using the eqn of line: ax+by+c=0; y = (-c-ax)/b, http://ai.stanford.edu/~mitul/cs223b/draw_epipolar.m
    # b == 0 (vertical line) yields inf/nan, which matplotlib simply does not
    # draw; silence the resulting RuntimeWarnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        right_epipolar_y = (-right_P[2:3, :] -
                            right_P[0:1, :] * right_epipolar_x) / right_P[1:2, :]

    # One random color value per correspondence, reused on both subplots.
    colors = np.random.rand(x2.shape[0])
    plt.figure(figsize=(30, 8))
    plt.subplot(121)
    plt.imshow(img1_rgb,
               cmap=None if len(img1_rgb.shape) == 3 else plt.get_cmap('gray'))
    plt.scatter(x1[:, 0], x1[:, 1], s=50, c=colors, edgecolors='w')
    plt.subplot(122)
    plt.imshow(img2_rgb,
               cmap=None if len(img2_rgb.shape) == 3 else plt.get_cmap('gray'))
    plt.plot(right_epipolar_x, right_epipolar_y)
    plt.scatter(x2[:, 0], x2[:, 1], s=50, c=colors, edgecolors='w')
    # Clip to the image and flip y so row 0 is at the top.
    plt.xlim(0, im_shape[1] - 1)
    plt.ylim(im_shape[0] - 1, 0)
    plt.show()
def show_epipolar_normalized(x1, x2, img1_rgb, img2_rgb, F_gt, im_shape):
    """Plot matched points and the epipolar lines of ``x1`` without images.

    Both subplots use a symmetric window ``[-W, W] x [-H, H]`` (y flipped,
    equal aspect), which suits coordinates centered on the image center.
    ``img1_rgb`` / ``img2_rgb`` are accepted for interface compatibility with
    the other epipolar plotting helpers but are not drawn.

    Args:
        x1, x2: matched points; column 0 is x, column 1 is y.
        img1_rgb, img2_rgb: unused.
        F_gt: fundamental matrix applied to the homogenized ``x1``.
        im_shape: ``im_shape[0]`` is used as H and ``im_shape[1]`` as W.
    """
    N_points = x1.shape[0]
    x1_homo = utils_misc.homo_np(x1)
    # Column i holds the coefficients (a, b, c) of the line ax + by + c = 0.
    right_P = np.matmul(F_gt, x1_homo.T)
    # Sample each line at x = -W and x = +W so it spans the symmetric window.
    right_epipolar_x = np.tile(np.array([[-1.], [1.]]), N_points) * im_shape[1]
    # Using the eqn of line: ax+by+c=0; y = (-c-ax)/b, http://ai.stanford.edu/~mitul/cs223b/draw_epipolar.m
    with np.errstate(divide='ignore', invalid='ignore'):
        right_epipolar_y = (-right_P[2:3, :] -
                            right_P[0:1, :] * right_epipolar_x) / right_P[1:2, :]

    def _set_symmetric_axes():
        # Symmetric limits, y flipped to match image-row orientation.
        plt.xlim(-im_shape[1], im_shape[1])
        plt.ylim(im_shape[0], -im_shape[0])
        plt.gca().set_aspect('equal', adjustable='box')

    colors = np.random.rand(x2.shape[0])
    plt.figure(figsize=(30, 8))

    plt.subplot(121)
    plt.scatter(x1[:, 0], x1[:, 1], s=50, c=colors, edgecolors='w')
    _set_symmetric_axes()

    plt.subplot(122)
    plt.plot(right_epipolar_x, right_epipolar_y)
    plt.scatter(x2[:, 0], x2[:, 1], s=50, c=colors, edgecolors='w')
    _set_symmetric_axes()
    plt.show()
    def construct_sample(self, scene_data, idx, frame_id, show_zoom_info):
        """Build the sample dict for frame ``idx`` of ``scene_data``.

        Always contains ``"img"`` and ``"id"`` (= ``frame_id``).  Depending on
        the loader flags it may also contain:

        * ``self.get_X``: ``"X_cam2_vis"`` / ``"X_cam0_vis"`` -- the points
          returned by ``rectify`` indexed by ``val_idxes``, as float32.  When
          ``self.load_velo`` returns ``None`` an error is logged and these keys
          are omitted.
        * ``self.get_pose``: ``"pose"`` -- ``scene_data["poses"][idx]`` as float32.
        * ``self.get_sift``: ``"sift_kp"`` / ``"sift_des"`` -- SIFT keypoints found
          on ``img_ori``, multiplied by ``zoom_xy`` and re-indexed by
          ``crop_or_pad_choice`` when their count differs from ``self.sift_num``.
        * ``self.get_SP``: ``"SP_kp"`` / ``"SP_des"`` -- output of ``self.fe.run``
          with the first two keypoint columns multiplied by ``zoom_xy``.
        """
        img, zoom_xy, img_ori = self.load_image(scene_data, idx,
                                                show_zoom_info)
        sample = {"img": img, "id": frame_id}

        # get 3d points
        if self.get_X:
            # feed in intrinsics for TUM to extract depth
            velo = self.load_velo(
                scene_data, idx,
                scene_data["calibs"].get("P_rect_noScale", None))
            if velo is None:
                # The message promises a skip: previously the code fell through
                # and crashed in homo_np(None). Now the X_* keys are left out.
                logging.error("0 velo in %s. Skipped." % scene_data["dir"])
            else:
                # change to homogeneous coordinates
                velo_homo = utils_misc.homo_np(velo)
                logging.debug(f"velo_homo: {velo_homo.shape}")
                val_idxes, X_rect, X_cam0 = rectify(
                    velo_homo, scene_data["calibs"])  # list, [N, 3]
                logging.debug(f"X_rect: {X_rect.shape}")
                logging.debug(f"X_cam0: {X_cam0.shape}")
                logging.debug(f"val_idxes: {len(val_idxes)}")

                sample["X_cam2_vis"] = X_rect[val_idxes].astype(np.float32)
                sample["X_cam0_vis"] = X_cam0[val_idxes].astype(np.float32)
        if self.get_pose:
            sample["pose"] = scene_data["poses"][idx].astype(np.float32)
        if self.get_sift:
            # Detect on img_ori, then rescale the coordinates by zoom_xy
            # (presumably into the resolution of `img` -- confirm in load_image).
            kp, des = self.sift.detectAndCompute(
                img_ori, None)  ## IMPORTANT: normalize these points
            x_all = np.array([p.pt for p in kp])
            x_all = (x_all * np.array([[zoom_xy[0], zoom_xy[1]]])).astype(
                np.float32)
            if x_all.shape[0] != self.sift_num:
                choice = crop_or_pad_choice(x_all.shape[0],
                                            self.sift_num,
                                            shuffle=True)
                x_all = x_all[choice]
                des = des[choice]
            sample["sift_kp"] = x_all
            sample["sift_des"] = des
        if self.get_SP:
            img_ori_gray = cv2.cvtColor(img_ori, cv2.COLOR_RGB2GRAY)
            # [1, 1, H, W] float tensor scaled to [0, 1]; a separate name keeps
            # it from shadowing `img`, which is already stored in `sample`.
            img_gray_th = (torch.from_numpy(img_ori_gray).float().unsqueeze(
                0).unsqueeze(0).float() / 255.0)
            pts, desc, _, _ = self.fe.run(img_gray_th)
            pts = pts[0].T  # [N, 3]
            pts[:, :2] = (pts[:, :2] *
                          np.array([[zoom_xy[0], zoom_xy[1]]])).astype(
                              np.float32)
            desc = desc[0].T  # [N, 256]
            sample["SP_kp"] = pts
            sample["SP_des"] = desc
        return sample
def epi_distance_np(F, X, Y, if_homo=False):
    """Epipolar point-to-line distances for correspondences X <-> Y (not squared).

    Reference: https://arxiv.org/pdf/1706.07886.pdf

    Args:
        F: fundamental matrix, [3, 3]; in the batched case [B, 3, 3] (a single
            [3, 3] matrix is broadcast over the batch).
        X, Y: corresponding points, [N, D] or batched [B, N, D]. After
            homogenization the last dimension must be 3.
        if_homo: True if X and Y are already homogeneous; otherwise they are
            passed through ``utils_misc.homo_np``.

    Returns:
        (dist3, dist1, dist2), each [N] (or [B, N]):
        dist1 -- |y^T F x| / ||(F x)[:2]||, distance of y to the line F x;
        dist2 -- |y^T F x| / ||(F^T y)[:2]||, distance of x to the line F^T y;
        dist3 -- dist1 + dist2.
    """
    if not if_homo:
        X = utils_misc.homo_np(X)
        Y = utils_misc.homo_np(Y)
    if X.ndim == 2:
        Xt, Yt = X.T, Y.T  # [3, N]
        Fx1 = F @ Xt  # lines in image 2, one column per point
        Fx2 = F.T @ Yt  # lines in image 1, one column per point
        # |y_i^T F x_i| row-wise, without building the N x N matrix Y F X^T.
        nominator = np.abs(np.sum(Yt * Fx1, axis=0))
        denom_recp_Y_to_FX = 1. / np.sqrt(Fx1[0]**2 + Fx1[1]**2)
        denom_recp_X_to_FY = 1. / np.sqrt(Fx2[0]**2 + Fx2[1]**2)
    else:
        # Batched. The old code used torch-style np.transpose(a, (1, 2)),
        # np.diagonal(axis=...) and Y @ F @ X, all of which raise in NumPy.
        Xt = np.swapaxes(X, 1, 2)  # [B, 3, N]
        Yt = np.swapaxes(Y, 1, 2)  # [B, 3, N]
        Fx1 = F @ Xt
        Fx2 = np.swapaxes(F, -1, -2) @ Yt
        nominator = np.abs(np.sum(Yt * Fx1, axis=1))  # [B, N]
        denom_recp_Y_to_FX = 1. / np.sqrt(Fx1[:, 0]**2 + Fx1[:, 1]**2)
        denom_recp_X_to_FY = 1. / np.sqrt(Fx2[:, 0]**2 + Fx2[:, 1]**2)
    dist1 = nominator * denom_recp_Y_to_FX
    dist2 = nominator * denom_recp_X_to_FY
    dist3 = nominator * (denom_recp_Y_to_FX + denom_recp_X_to_FY)
    return dist3, dist1, dist2
def show_epipolar_rui_gtEst(x1,
                            x2,
                            img1_rgb,
                            img2_rgb,
                            F_gt,
                            F_est,
                            im_shape,
                            title_append=''):
    """Overlay GT (blue) and estimated (red) epipolar lines of ``x1`` on image 2.

    Draws ``img2_rgb`` with the points ``x2``; for each ``x1[i]`` the lines
    ``F_gt @ homo(x1[i])`` and ``F_est @ homo(x1[i])`` are drawn across the full
    image width.  ``img1_rgb`` is accepted for interface compatibility but is
    not drawn.

    Args:
        x1, x2: matched points; column 0 is x, column 1 is y.
        F_gt, F_est: ground-truth and estimated fundamental matrices.
        im_shape: ``im_shape[0]`` is the height, ``im_shape[1]`` the width.
        title_append: text appended to the figure title.
    """
    x1_homo = utils_misc.homo_np(x1)
    # Sample every line at x = 0 and x = W (one column per point).
    line_x = np.tile(np.array([[0], [1]]), x1.shape[0]) * im_shape[1]

    def _epipolar_y(F):
        # Lines ax + by + c = 0 with (a, b, c) = F @ homo(x1_i); y = (-c-ax)/b,
        # http://ai.stanford.edu/~mitul/cs223b/draw_epipolar.m
        coeffs = np.matmul(F, x1_homo.T)
        with np.errstate(divide='ignore', invalid='ignore'):
            return (-coeffs[2:3, :] - coeffs[0:1, :] * line_x) / coeffs[1:2, :]

    plt.figure(figsize=(60, 8))
    plt.imshow(img2_rgb,
               cmap=None if len(img2_rgb.shape) == 3 else plt.get_cmap('gray'))

    plt.plot(line_x, _epipolar_y(F_gt), 'b', linewidth=0.5)
    plt.scatter(x2[:, 0], x2[:, 1], s=50, edgecolors='w')
    plt.plot(line_x, _epipolar_y(F_est), 'r', linewidth=0.5)

    # Clip to the image and flip y so row 0 is at the top.
    plt.xlim(0, im_shape[1] - 1)
    plt.ylim(im_shape[0] - 1, 0)
    plt.title('Blue lines for GT F; Red lines for est. F. -- ' + title_append)
    plt.show()
# Example #6
0
 def rectify_all(self, visualize=False):
     """Run ``self.rectify`` on the velodyne scan of every frame.

     For each frame, get the visible points on front view with identity left
     camera, as well as indexes of points on both left/right images.

     Args:
         visualize: if truthy, ``self.rectify`` is asked to visualize every
             100th frame (i % 100 == 0).

     Returns:
         (val_idxes_list, X_rect_list): one entry per frame; the same lists are
         stored on ``self.val_idxes_list`` / ``self.X_rect_list``.
     """
     print('Rectifying...')
     self.val_idxes_list = []
     self.X_rect_list = []
     # Materialize the scans once. The old code called list(self.dataset.velo)
     # inside the loop, re-reading every scan for every frame (O(N^2) loads).
     velo_all = list(self.dataset.velo) if self.N_frames > 0 else []
     for i in range(self.N_frames):
         print(i, self.N_frames)
         velo = velo_all[i][:, :3]  # [N, 4] -> first three columns
         velo_reproj = utils_misc.homo_np(velo)
         val_idxes, X_rect = self.rectify(velo_reproj,
                                          self.dataset_rgb[i][0],
                                          self.dataset_rgb[i][1],
                                          visualize=(bool(visualize)
                                                     and i % 100 == 0))
         self.val_idxes_list.append(val_idxes)
         self.X_rect_list.append(X_rect)
     print('Finished rectifying all frames.')
     return self.val_idxes_list, self.X_rect_list
# Example #7
0
    def show_demo(self):
        """Show the first frame's velodyne points projected onto both RGB images.

        For ``'leftRGB'`` and ``'rightRGB'`` the first scan is projected with
        ``P_rect @ R_cam2rect @ velo2cam``, rounded (minus 1, as in the KITTI
        MATLAB code), restricted to points inside ``self.im_shape`` and scattered
        over the image colored by inverse depth.  Nothing is shown when
        ``self.N_frames`` is 0.
        """
        if self.N_frames < 1:
            return
        # Only the first frame is ever displayed. The old code projected every
        # frame (re-reading all scans via list(self.dataset.velo)[i]) and then
        # discarded the results.
        velo = next(iter(self.dataset.velo))  # [N, 4]
        velo = velo[:, :3]
        velo_reproj = utils_misc.homo_np(velo)

        for cam_iter, cam in enumerate(['leftRGB', 'rightRGB']):
            # P_rect_0[0-3]: 3x4 projection matrix after rectification; the reprojection matrix in MV3D
            P_rect = self.P_rects[cam]
            # Maps homogeneous velodyne points to homogeneous image points.
            P_velo2im = np.dot(np.dot(P_rect, self.R_cam2rect), self.velo2cam)

            velo_pts_im = np.dot(P_velo2im, velo_reproj.T).T  # [*, 3]
            velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][
                ..., np.newaxis]

            # check if in bounds
            # use minus 1 to get the exact same value as KITTI matlab code
            velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1
            velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1
            in_image = ((velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0)
                        & (velo_pts_im[:, 0] < self.im_shape[1])
                        & (velo_pts_im[:, 1] < self.im_shape[0]))
            velo_pts_im = velo_pts_im[in_image, :]

            print(
                'Demo: Showing left/right data (unfiltered and unrectified) of the first frame.'
            )
            plt.figure(figsize=(30, 8))
            plt.imshow(self.dataset_rgb[0][cam_iter])
            # np.int was removed in NumPy 1.24; the builtin int is the same dtype.
            plt.scatter(velo_pts_im[:, 0].astype(int),
                        velo_pts_im[:, 1].astype(int),
                        s=2,
                        c=1. / velo_pts_im[:, 2])
            plt.xlim(0, self.im_shape[1] - 1)
            plt.ylim(self.im_shape[0] - 1, 0)
            plt.title(cam)
            plt.show()
def reproj_and_scatter(Rt,
                       X_rect,
                       im_rgb,
                       kitti_two_frame_loader=None,
                       visualize=True,
                       title_appendix='',
                       param_list=None,
                       set_lim=False,
                       debug=True):
    """Project 3-D points with ``K @ Rt`` and optionally scatter them on an image.

    Args:
        Rt: pose applied to the homogenized ``X_rect`` before ``K``.
        X_rect: 3-D points, one per row (homogenized via ``utils_misc.homo_np``).
        im_rgb: image drawn under the points when ``visualize`` is True.
        kitti_two_frame_loader: if given, ``K`` and ``im_shape`` are read from
            its ``.K`` / ``.im_shape`` attributes.
        visualize: if True, draw with ``scatter_xy``; otherwise only compute the
            in-image mask with ``utils_misc.within``.
        title_appendix: appended to the plot title.
        param_list: ``[K, im_shape]``; required when ``kitti_two_frame_loader``
            is None.  Defaults to None instead of a shared mutable ``[]``.
        set_lim: forwarded to ``scatter_xy``.
        debug: print where the parameters are read from.

    Returns:
        (val_inds, x1): the in-image mask and the projected 2-D points.

    Raises:
        IndexError: if neither a loader nor a two-element ``param_list`` is given.
    """
    if kitti_two_frame_loader is None:
        if debug:
            print('Reading from input list of param_list=[K, im_shape].')
        if param_list is None or len(param_list) < 2:
            raise IndexError(
                'param_list must be [K, im_shape] when kitti_two_frame_loader is None.')
        K = param_list[0]
        im_shape = param_list[1]
    else:
        K = kitti_two_frame_loader.K
        im_shape = kitti_two_frame_loader.im_shape

    x1_homo = np.matmul(K, np.matmul(Rt, utils_misc.homo_np(X_rect).T)).T
    # Perspective division; x1_homo[:, 2] is passed on as the per-point depth.
    x1 = x1_homo[:, 0:2] / x1_homo[:, 2:3]
    if visualize:
        plt.figure(figsize=(30, 8))
        cmap = None if len(
            np.array(im_rgb).shape) == 3 else plt.get_cmap('gray')
        plt.imshow(im_rgb, cmap=cmap)
        val_inds = scatter_xy(
            x1,
            x1_homo[:, 2],
            im_shape,
            'Reprojection to cam 2 with rectified X and camera_' +
            title_appendix,
            new_figure=False,
            set_lim=set_lim)
    else:
        val_inds = utils_misc.within(x1[:, 0], x1[:, 1], im_shape[1],
                                     im_shape[0])
    return val_inds, x1
# Example #9
0
 def construct_sample(self, scene_data, idx, frame_id, show_zoom_info):
     """Build the sample dict for frame ``idx`` of ``scene_data``.

     Always contains ``'img'`` and ``'id'``; optional entries depend on the
     ``get_X`` / ``get_pose`` / ``get_sift`` / ``get_SP`` flags.  When
     ``load_velo`` returns None, an error is logged and the ``'X_cam*_vis'``
     keys are omitted (previously the code crashed in ``homo_np(None)`` right
     after logging "Skipped.").
     """
     img, zoom_xy, img_ori = self.load_image(scene_data, idx, show_zoom_info)
     sample = {"img": img, "id": frame_id}
     if self.get_X:
         velo = load_velo(scene_data, idx)
         if velo is None:
             logging.error('0 velo in %s. Skipped.'%scene_data['dir'])
         else:
             velo_homo = utils_misc.homo_np(velo)
             val_idxes, X_rect, X_cam0 = rectify(velo_homo, scene_data['calibs']) # list, [N, 3]
             sample['X_cam2_vis'] = X_rect[val_idxes].astype(np.float32)
             sample['X_cam0_vis'] = X_cam0[val_idxes].astype(np.float32)
     if self.get_pose:
         sample['pose'] = scene_data['poses'][idx].astype(np.float32)
     if self.get_sift:
         # Detect on img_ori, then rescale coordinates by zoom_xy.
         kp, des = self.sift.detectAndCompute(img_ori, None) ## IMPORTANT: normalize these points
         x_all = np.array([p.pt for p in kp])
         x_all = (x_all * np.array([[zoom_xy[0], zoom_xy[1]]])).astype(np.float32)
         if x_all.shape[0] != self.sift_num:
             choice = crop_or_pad_choice(x_all.shape[0], self.sift_num, shuffle=True)
             x_all = x_all[choice]
             des = des[choice]
         sample['sift_kp'] = x_all
         sample['sift_des'] = des
     if self.get_SP:
         img_ori_gray = cv2.cvtColor(img_ori, cv2.COLOR_RGB2GRAY)
         # [1, 1, H, W] float tensor in [0, 1]; separate name so `img` is not shadowed.
         img_gray_th = torch.from_numpy(img_ori_gray).float().unsqueeze(0).unsqueeze(0).float() / 255.
         pts, desc, _, _ = self.fe.run(img_gray_th)
         pts = pts[0].T # [N, 3]
         pts[:, :2] = (pts[:, :2] * np.array([[zoom_xy[0], zoom_xy[1]]])).astype(np.float32)
         desc = desc[0].T # [N, 256]
         sample['SP_kp'] = pts
         sample['SP_des'] = desc
     return sample
    def collect_scene_from_drive(self, drive_path):
        """Build one ``scene_data`` dict per camera id in ``self.cam_ids`` for a drive.

        Each dict holds the sorted image file list, frame ids, optional per-frame
        SIFT keypoints/descriptors, calibration (image shape, zoom, intrinsics,
        rectification/velodyne transforms read from the RAW calib files), the
        per-frame poses and, if ``self.get_X``, the rectified velodyne points.

        Args:
            drive_path: drive directory. It is used with ``os.path.join``, string
                concatenation and ``drive_path[-2:]`` (the sequence id), so it is
                presumably a ``str`` -- confirm with callers.

        Returns:
            list of ``scene_data`` dicts. Returns ``[]`` -- also discarding scenes
            already built for earlier cameras -- when the first image of a camera
            is None, or when ``self.get_X`` is set and the first velodyne scan is None.
        """
        train_scenes = []
        for c in self.cam_ids:
            scene_data = {'cid': c, 'cid_num': self.cid_to_num[c], 'dir': Path(drive_path), 'rel_path': Path(drive_path).name + '_' + c}
            # Frames are the sorted PNG files under <drive_path>/image_<cid_num>/.
            img_dir = os.path.join(drive_path, 'image_%d'%scene_data['cid_num'])
            scene_data['img_files'] = sorted(glob(img_dir + '/*.png'))
            scene_data['N_frames'] = len(scene_data['img_files'])
            scene_data['frame_ids'] = ['{:06d}'.format(i) for i in range(scene_data['N_frames'])]

            # Check images and optionally get SIFT
            img_shape = None
            zoom_xy = None
            if self.get_sift:
                logging.info('Getting SIFT...'+drive_path)
                scene_data['sift_kp'] = []
                scene_data['sift_des'] = []
            # Only the first load_image call gets show_zoom_info=True.
            show_zoom_info = True
            for idx in tqdm(range(scene_data['N_frames'])):
                img, zoom_xy = self.load_image(scene_data, idx, show_zoom_info)
                show_zoom_info = False
                # NOTE(review): a None image at idx > 0 takes the else branch
                # and fails on img.shape.
                if img is None and idx==0:
                    logging.warning('0 images in %s. Skipped.'%drive_path)
                    return []
                else:
                    # All frames of a sequence must share the first frame's shape.
                    if img_shape is not None:
                        assert img_shape == img.shape, 'Inconsistent image shape in seq %s!'%drive_path
                    else:
                        img_shape = img.shape
                if self.get_sift:
                    # logging.info('Getting sift for frame %d/%d.'%(idx, scene_data['N_frames']))
                    kp, des = self.sift.detectAndCompute(img, None) ## IMPORTANT: normalize these points
                    x_all = np.array([p.pt for p in kp])
                    # When the count differs from self.sift_num, re-index keypoints and
                    # descriptors with crop_or_pad_choice (presumably to exactly sift_num).
                    if x_all.shape[0] != self.sift_num:
                        choice = crop_or_pad_choice(x_all.shape[0], self.sift_num, shuffle=True)
                        x_all = x_all[choice]
                        des = des[choice]
                    scene_data['sift_kp'].append(x_all)
                    scene_data['sift_des'].append(des)
            if self.get_sift:
                assert scene_data['N_frames']==len(scene_data['sift_kp']), 'scene_data[N_frames]!=len(scene_data[sift_kp]), %d!=%d'%(scene_data['N_frames'], len(scene_data['sift_kp']))

            # NOTE(review): if N_frames == 0, img_shape is still None here and
            # img_shape[0] raises TypeError. zoom_xy is the value from the last
            # load_image call; 'rescale' is True unless it equals (1., 1.).
            scene_data['calibs'] = {'im_shape': [img_shape[0], img_shape[1]], 'zoom_xy': zoom_xy, 'rescale': True if zoom_xy != (1., 1.) else False}

            # Get geo params from the RAW dataset calibs
            P_rect_ori_dict = self.get_P_rect(scene_data, scene_data['calibs'])
            # Intrinsics: first three columns of this camera's projection matrix.
            intrinsics = P_rect_ori_dict[c][:,:3]
            calibs_rects = self.get_rect_cams(intrinsics, P_rect_ori_dict['02'])

            # The last two characters of drive_path select the RAW drive; its
            # first 10 characters are the date directory holding the calib files.
            drive_in_raw = self.map_to_raw[drive_path[-2:]]
            date = drive_in_raw[:10]
            # NOTE(review): `seq` is not used in this method.
            seq = drive_in_raw[-4:]
            calib_path_in_raw = Path(self.dataset_dir)/'raw'/date
            print('++++', calib_path_in_raw)
            imu2velo_dict = read_calib_file(calib_path_in_raw/'calib_imu_to_velo.txt')
            velo2cam_dict = read_calib_file(calib_path_in_raw/'calib_velo_to_cam.txt')
            cam2cam_dict = read_calib_file(calib_path_in_raw/'calib_cam_to_cam.txt')
            velo2cam_mat = transform_from_rot_trans(velo2cam_dict['R'], velo2cam_dict['T'])
            # NOTE(review): `imu2velo_mat` is computed but not used in this method.
            imu2velo_mat = transform_from_rot_trans(imu2velo_dict['R'], imu2velo_dict['T'])
            # Rectification rotation R_rect_00 with zero translation.
            cam_2rect_mat = transform_from_rot_trans(cam2cam_dict['R_rect_00'], np.zeros(3))
            scene_data['calibs'].update({'K': intrinsics, 'P_rect_ori_dict': P_rect_ori_dict, 'cam_2rect': cam_2rect_mat, 'velo2cam': velo2cam_mat})
            scene_data['calibs'].update(calibs_rects)

            # Get pose
            # One 3x4 pose per frame from <dataset_dir>/poses/<seq>.txt.
            # NOTE(review): uses `self.dataset_dir/...` directly, so dataset_dir must
            # already be a Path here (above it is wrapped in Path()).
            poses = np.genfromtxt(self.dataset_dir/'poses'/'{}.txt'.format(drive_path[-2:])).astype(np.float64).reshape(-1, 3, 4)
            assert scene_data['N_frames']==poses.shape[0], 'scene_data[N_frames]!=poses.shape[0], %d!=%d'%(scene_data['N_frames'], poses.shape[0])
            scene_data['poses'] = poses

            # 'Rtl_gt' is not set in this method; presumably it comes from calibs_rects.
            scene_data['Rt_cam2_gt'] = scene_data['calibs']['Rtl_gt']

            # Get velo
            if self.get_X:
                logging.info('Getting X...'+drive_path)
                # for each frame, get the visible points on front view with identity left camera, as well as indexes of points on both left/right images
                val_idxes_list = []
                X_rect_list = []
                X_cam0_list = []
                for idx in tqdm(range(scene_data['N_frames'])):
                    velo = self.load_velo(scene_data, idx)
                    # Stop at the first missing scan.
                    if velo is None:
                        break
                    velo_homo = utils_misc.homo_np(velo)
                    val_idxes, X_rect, X_cam0 = rectify(velo_homo, scene_data['calibs']) # list, [N, 3]
                    val_idxes_list.append(val_idxes)
                    X_rect_list.append(X_rect)
                    X_cam0_list.append(X_cam0)
                # Missing first scan -> skip the drive. A scan missing later makes
                # the N_frames assertion below fail instead.
                if velo is None and idx==0:
                    logging.warning('0 velo in %s. Skipped.'%drive_path)
                    return []
                scene_data['val_idxes'] = val_idxes_list
                scene_data['X_cam2'] = X_rect_list
                scene_data['X_cam0'] = X_cam0_list
                # Check number of velo frames
                assert scene_data['N_frames']==len(scene_data['X_cam2']), 'scene_data[N_frames]!=len(scene_data[X_cam2]), %d!=%d'%(scene_data['N_frames'], len(scene_data['X_cam2']))

            train_scenes.append(scene_data)
        return train_scenes
def draw_corr_widths_and_epi(F_gt,
                             im1,
                             im2,
                             x1,
                             x2,
                             linewidth,
                             title='',
                             rescale=True,
                             scale=1.):
    """Draw correspondences with per-match widths plus their epipolar lines.

    ``im1`` and ``im2`` are placed side by side.  For every i a segment joins
    ``x1[i]`` (left image) with ``x2[i]`` (right image, shifted by the image
    width), and the epipolar line ``F_gt @ homo(x1[i])`` is drawn across the
    right image in the same color.  Each segment color is printed.

    Previously the epipolar lines of ALL points were recomputed and re-plotted
    inside the per-correspondence loop (N^2 line artists, shifting the color
    cycle); now they are computed once and line i is drawn with segment i.

    Args:
        F_gt: fundamental matrix applied to the homogenized ``x1``.
        im1, im2: images of identical shape (asserted).
        x1, x2: matched points; column 0 is x, column 1 is y.
        linewidth: per-correspondence width values, indexed like ``x1``.
        title: figure title.
        rescale: if True, width is 5 when ``linewidth[i] < 2`` and 10 otherwise.
        scale: multiplier for ``linewidth[i]`` when ``rescale`` is False.
    """
    im_shape = im1.shape
    assert im1.shape == im2.shape, 'Shape mismatch between im1 and im2! @draw_corr()'
    width_px = im_shape[1]
    # Right-image points live in the second half of the stacked canvas.
    x2_shifted = x2.copy()
    x2_shifted[:, 0] = x2_shifted[:, 0] + width_px

    N_points = x1.shape[0]
    if N_points > 0:
        x1_homo = utils_misc.homo_np(x1)
        # Column i holds (a, b, c) of the line ax + by + c = 0 in image 2.
        right_P = np.matmul(F_gt, x1_homo.T)
        # Evaluate every line at x = 0 and x = W of the right image.
        right_epipolar_x = np.tile(np.array([[0], [1]]), N_points) * width_px
        # Using the eqn of line: ax+by+c=0; y = (-c-ax)/b, http://ai.stanford.edu/~mitul/cs223b/draw_epipolar.m
        with np.errstate(divide='ignore', invalid='ignore'):
            right_epipolar_y = (-right_P[2:3, :] - right_P[0:1, :] *
                                right_epipolar_x) / right_P[1:2, :]

    plt.figure(figsize=(60, 8))
    plt.imshow(np.hstack((im1, im2)))
    for i in range(N_points):
        if rescale:
            width = 5 if linewidth[i] < 2 else 10
        else:
            width = linewidth[i] * scale
        p = plt.plot(np.vstack((x1[i, 0], x2_shifted[i, 0])),
                     np.vstack((x1[i, 1], x2_shifted[i, 1])),
                     linewidth=width,
                     marker='o',
                     markersize=8)
        color = p[0].get_color()
        print(color)
        # Epipolar line of x1[i], shifted onto the right image, same color.
        plt.plot(right_epipolar_x[:, i] + width_px, right_epipolar_y[:, i],
                 color=color)
    plt.xlim(0, width_px * 2 - 1)
    plt.ylim(im_shape[0] - 1, 0)

    plt.title(title, {'fontsize': 40})
    plt.show()
# Example #12
0
    def eval_one_sample(self, sample):
        import torch
        import dsac_tools.utils_F as utils_F  # If cannot find: export KITTI_UTILS_PATH='/home/ruizhu/Documents/Projects/kitti_instance_RGBD_utils'
        import dsac_tools.utils_opencv as utils_opencv  # If cannot find: export KITTI_UTILS_PATH='/home/ruizhu/Documents/Projects/kitti_instance_RGBD_utils'
        import dsac_tools.utils_vis as utils_vis  # If cannot find: export KITTI_UTILS_PATH='/home/ruizhu/Documents/Projects/kitti_instance_RGBD_utils'
        import dsac_tools.utils_misc as utils_misc  # If cannot find: export KITTI_UTILS_PATH='/home/ruizhu/Documents/Projects/kitti_instance_RGBD_utils'
        import dsac_tools.utils_geo as utils_geo  # If cannot find: export KITTI_UTILS_PATH='/home/ruizhu/Documents/Projects/kitti_instance_RGBD_utils'
        from train_good_utils import val_rt, get_matches_from_SP

        # params
        config = self.config
        net_dict = self.net_dict
        if_SP = self.config["model"]["if_SP"]
        if_quality = self.config["model"]["if_quality"]
        device = self.device
        net_SP_helper = self.net_SP_helper

        task = "validating"
        imgs = sample["imgs"]  # [batch_size, H, W, 3]
        Ks = sample["K"].to(device)  # [batch_size, 3, 3]
        K_invs = sample["K_inv"].to(device)  # [batch_size, 3, 3]
        batch_size = Ks.size(0)
        scene_names = sample["scene_name"]
        frame_ids = sample["frame_ids"]
        scene_poses = sample[
            "relative_scene_poses"]  # list of sequence_length tensors, which with size [batch_size, 4, 4]; the first being identity, the rest are [[R; t], [0, 1]]
        if config["data"]["read_what"]["with_X"]:
            Xs = sample[
                "X_cam2s"]  # list of [batch_size, 3, Ni]; only support batch_size=1 because of variable points Ni for each sample
        # sift_kps, sift_deses = sample['sift_kps'], sample['sift_deses']
        assert sample["get_flags"]["have_matches"][0].numpy(
        ), "Did not find the corres files!"
        matches_all, matches_good = sample["matches_all"], sample[
            "matches_good"]
        quality_all, quality_good = sample["quality_all"], sample[
            "quality_good"]

        delta_Rtijs_4_4 = scene_poses[1].float(
        )  # [batch_size, 4, 4], asserting we have 2 frames where scene_poses[0] are all identities
        E_gts, F_gts = sample["E"], sample["F"]
        pts1_virt_normalizedK, pts2_virt_normalizedK = (
            sample["pts1_virt_normalized"].to(device),
            sample["pts2_virt_normalized"].to(device),
        )
        pts1_virt_ori, pts2_virt_ori = (
            sample["pts1_virt"].to(device),
            sample["pts2_virt"].to(device),
        )
        # pts1_virt_ori, pts2_virt_ori = sample['pts1_velo'].to(device), sample['pts2_velo'].to(device)

        # Get and Normalize points
        if if_SP:
            net_SP = net_dict["net_SP"]
            SP_processer, SP_tracker = (
                net_SP_helper["SP_processer"],
                net_SP_helper["SP_tracker"],
            )
            xs, offsets, quality = get_matches_from_SP(sample["imgs_grey"],
                                                       net_SP, SP_processer,
                                                       SP_tracker)
            matches_use = xs + offsets
            # matches_use = xs + offsets
            quality_use = quality
        else:
            # Get and Normalize points
            matches_use = matches_good  # [SWITCH!!!]
            quality_use = quality_good.to(
                device) if if_quality else None  # [SWITCH!!!]

        ## process x1, x2
        matches_use = matches_use.to(device)

        N_corres = matches_use.shape[
            1]  # 1311 for matches_good, 2000 for matches_all
        x1, x2 = (
            matches_use[:, :, :2],
            matches_use[:, :, 2:],
        )  # [batch_size, N, 2(W, H)]
        x1_normalizedK = utils_misc._de_homo(
            torch.matmul(
                torch.inverse(Ks),
                utils_misc._homo(x1).transpose(1, 2)).transpose(
                    1,
                    2))  # [batch_size, N, 2(W, H)], min/max_X=[-W/2/f, W/2/f]
        x2_normalizedK = utils_misc._de_homo(
            torch.matmul(
                torch.inverse(Ks),
                utils_misc._homo(x2).transpose(1, 2)).transpose(
                    1,
                    2))  # [batch_size, N, 2(W, H)], min/max_X=[-W/2/f, W/2/f]
        matches_use_normalizedK = torch.cat((x1_normalizedK, x2_normalizedK),
                                            2)

        matches_use_ori = torch.cat((x1, x2), 2)

        # Get image feats
        if config["model"]["if_img_feat"]:
            imgs = sample["imgs"]  # [batch_size, H, W, 3]
            imgs_stack = ((torch.cat(imgs, 3).float() - 127.5) /
                          127.5).permute(0, 3, 1, 2)

        qs_scene = sample["q_scene"].to(device)  # [B, 4, 1]
        ts_scene = sample["t_scene"].to(device)  # [B, 3, 1]
        qs_cam = sample["q_cam"].to(device)  # [B, 4, 1]
        ts_cam = sample["t_cam"].to(device)  # [B, 3, 1]

        t_scene_scale = torch.norm(ts_scene, p=2, dim=1, keepdim=True)

        # image_height, image_width = config['data']['image']['size'][0], config['data']['image']['size'][1]
        # mask_x1 = (matches_use_ori[:, :, 0] > (image_width/8.*3.)).byte() & (matches_use_ori[:, :, 0] < (image_width/8.*5.)).byte()
        # mask_x2 = (matches_use_ori[:, :, 2] > (image_width/8.*3.)).byte() & (matches_use_ori[:, :, 2] < (image_width/8.*5.)).byte()
        # mask_y1 = (matches_use_ori[:, :, 1] > (image_height/8.*3.)).byte() & (matches_use_ori[:, :, 1] < (image_height/8.*5.)).byte()
        # mask_y2 = (matches_use_ori[:, :, 3] > (image_height/8.*3.)).byte() & (matches_use_ori[:, :, 3] < (image_height/8.*5.)).byte()
        # mask_center = (~(mask_x1 & mask_y1)) & (~(mask_x2 & mask_y2))
        # matches_use_ori = (mask_center.float()).unsqueeze(-1) * matches_use_ori + torch.tensor([image_width/2., image_height/2., image_width/2., image_height/2.]).to(device).unsqueeze(0).unsqueeze(0) * (1- (mask_center.float()).unsqueeze(-1))
        # x1, x2 = matches_use_ori[:, :, :2], matches_use_ori[:, :, 2:] # [batch_size, N, 2(W, H)]

        data_batch = {
            "matches_xy_ori": matches_use_ori,
            "quality": quality_use,
            "x1_normalizedK": x1_normalizedK,
            "x2_normalizedK": x2_normalizedK,
            "Ks": Ks,
            "K_invs": K_invs,
            "matches_good_unique_nums": sample["matches_good_unique_nums"],
            "t_scene_scale": t_scene_scale,
        }
        # loss_params = {'model': config['model']['name'], 'clamp_at':config['model']['clamp_at'], 'depth': config['model']['depth']}
        loss_params = {
            "model": config["model"]["name"],
            "clamp_at": config["model"]["clamp_at"],
            "depth": config["model"]["depth"],
        }

        with torch.no_grad():
            outs = net_dict["net_deepF"](data_batch)

            pts1_eval, pts2_eval = pts1_virt_ori, pts2_virt_ori

            #     logits = outs['logits'] # [batch_size, N]
            #     logits_weights = F.softmax(logits, dim=1)
            logits_weights = outs["weights"]
            loss_E = 0.0

            F_out, T1, T2, out_a = (
                outs["F_est"],
                outs["T1"],
                outs["T2"],
                outs["out_layers"],
            )
            pts1_eval = torch.bmm(T1,
                                  pts1_virt_ori.permute(0, 2,
                                                        1)).permute(0, 2, 1)
            pts2_eval = torch.bmm(T2,
                                  pts2_virt_ori.permute(0, 2,
                                                        1)).permute(0, 2, 1)

            # pts1_eval = utils_misc._homo(F.normalize(pts1_eval[:, :, :2], dim=2))
            # pts2_eval = utils_misc._homo(F.normalize(pts2_eval[:, :, :2], dim=2))

            loss_layers = []
            losses_layers = []
            # losses = utils_F.compute_epi_residual(pts1_eval, pts2_eval, F_est, loss_params['clamp_at']) #- res.mean()
            # losses_layers.append(losses)
            # loss_all = losses.mean()
            # loss_layers.append(loss_all)
            out_a.append(F_out)
            loss_all = 0.0
            for iter in range(loss_params["depth"]):
                losses = utils_F.compute_epi_residual(pts1_eval, pts2_eval,
                                                      out_a[iter],
                                                      loss_params["clamp_at"])
                # losses = utils_F._YFX(pts1_eval, pts2_eval, out_a[iter], if_homo=True, clamp_at=loss_params['clamp_at'])
                losses_layers.append(losses)
                loss = losses.mean()
                loss_layers.append(loss)
                loss_all += loss

            loss_all = loss_all / len(loss_layers)

            F_ests = T2.permute(0, 2, 1).bmm(F_out.bmm(T1))
            E_ests = Ks.transpose(1, 2) @ F_ests @ Ks

            last_losses = losses_layers[-1].detach().cpu().numpy()
            print(last_losses)
            print(np.amax(last_losses, axis=1))

        # E_ests_list = []
        # for x1_single, x2_single, K, w in zip(x1, x2, Ks, logits_weights):
        #     E_est = utils_F._E_from_XY(x1_single, x2_single, K, torch.diag(w))
        #     E_ests_list.append(E_est)
        # E_ests = torch.stack(E_ests_list).to(device)
        # F_ests = utils_F._E_to_F(E_ests, Ks)
        K_np = Ks.cpu().detach().numpy()
        x1_np, x2_np = x1.cpu().detach().numpy(), x2.cpu().detach().numpy()
        E_est_np = E_ests.cpu().detach().numpy()
        F_est_np = F_ests.cpu().detach().numpy()
        delta_Rtijs_4_4_cpu_np = delta_Rtijs_4_4.cpu().numpy()

        # Tests and vis
        idx = 0
        img1 = imgs[0][idx].numpy().astype(np.uint8)
        img2 = imgs[1][idx].numpy().astype(np.uint8)
        img1_rgb, img2_rgb = img1, img2
        img1_rgb_np, img2_rgb_np = img1, img2
        im_shape = img1.shape
        x1 = x1_np[idx]
        x2 = x2_np[idx]
        #         utils_vis.draw_corr(img1, img2, x1, x2)

        delta_Rtij = delta_Rtijs_4_4_cpu_np[idx]
        print("----- delta_Rtij", delta_Rtij)
        delta_Rtij_inv = np.linalg.inv(delta_Rtij)
        K = K_np[idx]
        F_gt_th = F_gts[idx].cpu()
        F_gt = F_gt_th.numpy()
        E_gt_th = E_gts[idx].cpu()
        E_gt = E_gt_th.numpy()
        F_est = F_est_np[idx]
        E_est = E_est_np[idx]

        unique_rows_all, unique_rows_all_idxes = np.unique(np.hstack((x1, x2)),
                                                           axis=0,
                                                           return_index=True)
        mask_sample = np.random.choice(x1.shape[0], 100)
        angle_R = utils_geo.rot12_to_angle_error(np.eye(3),
                                                 delta_Rtij_inv[:3, :3])
        angle_t = utils_geo.vector_angle(np.array([[0.0], [0.0], [1.0]]),
                                         delta_Rtij_inv[:3, 3:4])
        print(
            ">>>>>>>>>>>>>>>> Between frames: The rotation angle (degree) %.4f, and translation angle (degree) %.4f"
            % (angle_R, angle_t))
        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[mask_sample],
            x2[mask_sample],
            linewidth=2.0,
            title="Sample of 100 corres.",
        )

        #         ## Baseline: 8-points
        #         M_8point, error_Rt_8point, mask2_8point, E_est_8point = utils_opencv.recover_camera_opencv(K, x1, x2, delta_Rtij_inv, five_point=False, threshold=0.01)

        ## Baseline: 5-points
        five_point = False
        M_opencv, error_Rt_opencv, mask2, E_return = utils_opencv.recover_camera_opencv(
            K, x1, x2, delta_Rtij_inv, five_point=five_point, threshold=0.01)

        if five_point:
            E_est_opencv = E_return
            F_est_opencv = utils_F.E_to_F_np(E_est_opencv, K)
        else:
            E_est_opencv, F_est_opencv = E_return[0], E_return[1]

        ## Check geo dists
        print(f"K: {K}")
        x1_normalizedK = utils_misc.de_homo_np(
            (np.linalg.inv(K) @ utils_misc.homo_np(x1).T).T)
        x2_normalizedK = utils_misc.de_homo_np(
            (np.linalg.inv(K) @ utils_misc.homo_np(x2).T).T)
        K_th = torch.from_numpy(K)
        F_gt_normalized = K_th.t(
        ) @ F_gt_th @ K_th  # Should be identical to E_gts[idx]

        geo_dists = utils_F._sym_epi_dist(
            F_gt_normalized,
            torch.from_numpy(x1_normalizedK),
            torch.from_numpy(x2_normalizedK),
        ).numpy()
        geo_thres = 1e-4
        mask_in = geo_dists < geo_thres
        mask_out = geo_dists >= geo_thres

        mask_sample = mask2
        print(mask2.shape)
        np.set_printoptions(precision=8, suppress=True)

        ## Ours: Some analysis
        print("----- Oursssssssssss")
        scores_ori = logits_weights.cpu().numpy().flatten()
        import matplotlib.pyplot as plt

        plt.hist(scores_ori, 100)
        plt.show()
        sort_idxes = np.argsort(scores_ori[unique_rows_all_idxes])[::-1]
        scores = scores_ori[unique_rows_all_idxes][sort_idxes]
        num_corr = 100
        mask_conf = sort_idxes[:num_corr]
        # mask_sample = np.array(range(x1.shape[0]))[mask_sample][:20]

        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[unique_rows_all_idxes],
            x2[unique_rows_all_idxes],
            linewidth=2.0,
            title=f"All {unique_rows_all_idxes.shape[0]} correspondences",
        )

        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[unique_rows_all_idxes][mask_conf, :],
            x2[unique_rows_all_idxes][mask_conf, :],
            linewidth=2.0,
            title=f"Ours top {num_corr} confidents",
        )
        #         print('(%d unique corres)'%scores.shape[0])
        utils_vis.show_epipolar_rui_gtEst(
            x2[unique_rows_all_idxes][mask_conf, :],
            x1[unique_rows_all_idxes][mask_conf, :],
            img2_rgb,
            img1_rgb,
            F_gt.T,
            F_est.T,
            weights=scores_ori[unique_rows_all_idxes][mask_conf],
            im_shape=im_shape,
            title_append="Ours top %d with largest score points" %
            mask_conf.shape[0],
        )
        print(f"F_gt: {F_gt/F_gt[2, 2]}")
        print(f"F_est: {F_est/F_est[2, 2]}")
        error_Rt_est_ours, epi_dist_mean_est_ours, _, _, _, _, _, M_estW = val_rt(
            idx,
            K,
            x1,
            x2,
            E_est,
            E_gt,
            F_est,
            F_gt,
            delta_Rtij,
            five_point=False,
            if_opencv=False,
        )
        print(
            "Recovered by ours (camera): The rotation error (degree) %.4f, and translation error (degree) %.4f"
            % (error_Rt_est_ours[0], error_Rt_est_ours[1]))
        #         print(epi_dist_mean_est_ours, np.mean(epi_dist_mean_est_ours))
        print("%.2f, %.2f" % (
            np.sum(epi_dist_mean_est_ours < 0.1) /
            epi_dist_mean_est_ours.shape[0],
            np.sum(epi_dist_mean_est_ours < 1) /
            epi_dist_mean_est_ours.shape[0],
        ))

        ## OpenCV: Some analysis
        corres = np.hstack((x1[mask_sample, :], x2[mask_sample, :]))

        unique_rows = np.unique(corres,
                                axis=0) if corres.shape[0] > 0 else corres

        opencv_name = "5-point" if five_point else "8-point"
        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[mask_sample, :],
            x2[mask_sample, :],
            linewidth=2.0,
            title=f"OpenCV {opencv_name} inliers",
        )

        print("----- OpenCV %s (%d unique inliers)" %
              (opencv_name, unique_rows.shape[0]))
        utils_vis.show_epipolar_rui_gtEst(
            x2[mask_sample, :],
            x1[mask_sample, :],
            img2_rgb,
            img1_rgb,
            F_gt.T,
            F_est_opencv.T,
            weights=scores_ori[mask_sample],
            im_shape=im_shape,
            title_append="OpenCV 5-point with its inliers",
        )
        print(F_gt / F_gt[2, 2])
        print(F_est_opencv / F_est_opencv[2, 2])
        error_Rt_est_5p, epi_dist_mean_est_5p, _, _, _, _, _, M_estOpenCV = val_rt(
            idx,
            K,
            x1,
            x2,
            E_est_opencv,
            E_gt,
            F_est_opencv,
            F_gt,
            delta_Rtij,
            five_point=False,
            if_opencv=False,
        )
        print(
            "Recovered by OpenCV %s (camera): The rotation error (degree) %.4f, and translation error (degree) %.4f"
            % (opencv_name, error_Rt_est_5p[0], error_Rt_est_5p[1]))
        print("%.2f, %.2f" % (
            np.sum(epi_dist_mean_est_5p < 0.1) / epi_dist_mean_est_5p.shape[0],
            np.sum(epi_dist_mean_est_5p < 1) / epi_dist_mean_est_5p.shape[0],
        ))
        # dict_of_lists['opencv5p'].append((np.sum(epi_dist_mean_est_5p<0.1)/epi_dist_mean_est_5p.shape[0], np.sum(epi_dist_mean_est_5p<1)/epi_dist_mean_est_5p.shape[0]))
        # dict_of_lists['ours'].append((np.sum(epi_dist_mean_est_ours<0.1)/epi_dist_mean_est_ours.shape[0], np.sum(epi_dist_mean_est_ours<1)/epi_dist_mean_est_ours.shape[0]))

        print("+++ GT, Opencv_5p, Ours")
        np.set_printoptions(precision=4, suppress=True)
        print(delta_Rtij_inv[:3])
        print(
            np.hstack((
                M_opencv[:, :3],
                M_opencv[:, 3:4] / M_opencv[2, 3] * delta_Rtij_inv[2, 3],
            )))
        print(
            np.hstack((M_estW[:, :3],
                       M_estW[:, 3:4] / M_estW[2, 3] * delta_Rtij_inv[2, 3])))

        return {
            "img1_rgb": img1_rgb,
            "img2_rgb": img2_rgb,
            "delta_Rtij": delta_Rtij
        }