def _sym_epi_dist(F, X, Y, if_homo=False, clamp_at=None):
    # Note: returns the *squared* symmetric epipolar distance.
    if not if_homo:
        X = utils_misc._homo(X)
        Y = utils_misc._homo(Y)
    if len(X.size()) == 2:
        numerator = (torch.diag(Y @ F @ X.t()))**2
        Fx1 = torch.mm(F, X.t())
        Fx2 = torch.mm(F.t(), Y.t())
        denom_recp = 1. / (Fx1[0]**2 + Fx1[1]**2 + 1e-10) + 1. / (
            Fx2[0]**2 + Fx2[1]**2 + 1e-10)  # epsilon guard against zero denominators, matching the batched branch below
    else:
        numerator = (torch.diagonal(Y @ F @ X.transpose(1, 2), dim1=1,
                                    dim2=2))**2
        Fx1 = torch.matmul(F, X.transpose(1, 2))
        Fx2 = torch.matmul(F.transpose(1, 2), Y.transpose(1, 2))
        denom_recp = 1. / (Fx1[:, 0]**2 + Fx1[:, 1]**2 +
                           1e-10) + 1. / (Fx2[:, 0]**2 + Fx2[:, 1]**2 + 1e-10)

    errors = numerator * denom_recp

    if clamp_at is not None:
        errors = torch.clamp(errors, max=clamp_at)

    return errors
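# A minimal usage sketch (not from the original module); assumes utils_misc._homo
# appends a ones column to [N, 2] points. Builds a random rank-2 F and evaluates
# the squared symmetric epipolar distance on 10 random correspondences.
def _demo_sym_epi_dist():
    torch.manual_seed(0)
    X = torch.rand(10, 2)  # points in image 1, [N, 2]
    Y = torch.rand(10, 2)  # points in image 2, [N, 2]
    U, _, V = torch.svd(torch.rand(3, 3))
    F = U @ torch.diag(torch.tensor([1., 0.5, 0.])) @ V.t()  # enforce rank 2
    errors = _sym_epi_dist(F, X, Y, clamp_at=1.0)
    print(errors.shape)  # torch.Size([10])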
def _normalize_XY(X, Y):
    """ The Hartley normalization. Following https://github.com/marktao99/python/blob/da2682f8832483650b85b0be295ae7eaf179fcc5/CVP/samples/sfm.py#L157 
    corrected with https://www.mathworks.com/matlabcentral/fileexchange/27541-fundamental-matrix-computation
    and https://en.wikipedia.org/wiki/Eight-point_algorithm#The_normalized_eight-point_algorithm """
    if X.size()[0] != Y.size()[0]:
        raise ValueError("Number of points don't match.")
    X = utils_misc._homo(X)
    mean_1 = torch.mean(X[:, :2], dim=0, keepdim=True)
    S1 = np.sqrt(2) / torch.mean(torch.norm(X[:, :2] - mean_1, 2, dim=1))
    T1 = torch.tensor(
        [[S1, 0, -S1 * mean_1[0, 0]], [0, S1, -S1 * mean_1[0, 1]], [0, 0, 1]],
        device=X.device)
    X_normalized = utils_misc._de_homo(torch.mm(
        T1, X.t()).t())  # ideally zero mean (x, y), and sqrt(2) average norm


    Y = utils_misc._homo(Y)
    mean_2 = torch.mean(Y[:, :2], dim=0, keepdim=True)
    S2 = np.sqrt(2) / torch.mean(torch.norm(Y[:, :2] - mean_2, 2, dim=1))
    T2 = torch.tensor(
        [[S2, 0, -S2 * mean_2[0, 0]], [0, S2, -S2 * mean_2[0, 1]], [0, 0, 1]],
        device=X.device)
    Y_normalized = utils_misc._de_homo(torch.mm(T2, Y.t()).t())

    return X_normalized, Y_normalized, T1, T2
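# A sanity-check sketch (not original code) of the normalization invariants:
# after Hartley normalization the points should have ~zero mean and an average
# distance of sqrt(2) from the origin.
def _demo_normalize_XY():
    torch.manual_seed(0)
    X = torch.rand(50, 2) * 100.
    Y = torch.rand(50, 2) * 100.
    Xn, Yn, T1, T2 = _normalize_XY(X, Y)
    print(torch.mean(Xn, dim=0))                 # ~[0, 0]
    print(torch.mean(torch.norm(Xn, 2, dim=1)))  # ~1.4142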
def _sampson_dist(F, X, Y, if_homo=False):
    # Squared Sampson distance: first-order approximation of the geometric error.
    if not if_homo:
        X = utils_misc._homo(X)
        Y = utils_misc._homo(Y)
    if len(X.size()) == 2:
        numerator = (torch.diag(Y @ F @ X.t()))**2
        Fx1 = torch.mm(F, X.t())
        Fx2 = torch.mm(F.t(), Y.t())
        denom = Fx1[0]**2 + Fx1[1]**2 + Fx2[0]**2 + Fx2[1]**2
    else:
        numerator = (torch.diagonal(Y @ F @ X.transpose(1, 2), dim1=1,
                                    dim2=2))**2
        Fx1 = torch.matmul(F, X.transpose(1, 2))
        Fx2 = torch.matmul(F.transpose(1, 2), Y.transpose(1, 2))
        denom = Fx1[:, 0]**2 + Fx1[:, 1]**2 + Fx2[:, 0]**2 + Fx2[:, 1]**2

    errors = numerator / denom
    return errors
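# A sketch comparing Sampson and symmetric epipolar distances (assumes the
# helpers above); both are squared, and Sampson is a first-order approximation,
# so the two should be close when residuals are small.
def _demo_sampson_vs_sym():
    torch.manual_seed(0)
    X = torch.rand(10, 2)
    Y = torch.rand(10, 2)
    U, _, V = torch.svd(torch.rand(3, 3))
    F = U @ torch.diag(torch.tensor([1., 0.5, 0.])) @ V.t()  # rank-2 F
    print(_sampson_dist(F, X, Y))
    print(_sym_epi_dist(F, X, Y))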
def _normalize_XY_batch(X, Y):
    """ The Hartley normalization. Following https://github.com/marktao99/python/blob/da2682f8832483650b85b0be295ae7eaf179fcc5/CVP/samples/sfm.py#L157 
    corrected with https://www.mathworks.com/matlabcentral/fileexchange/27541-fundamental-matrix-computation
    and https://en.wikipedia.org/wiki/Eight-point_algorithm#The_normalized_eight-point_algorithm """
    # X: [batch_size, N, 2]
    if X.size()[1] != Y.size()[1]:
        raise ValueError("Number of points don't match.")
    X = utils_misc._homo(X)
    mean_1s = torch.mean(X[:, :, :2], dim=1, keepdim=True)
    S1s = np.sqrt(2) / torch.mean(torch.norm(X[:, :, :2] - mean_1s, 2, dim=2),
                                  dim=1)
    T1_list = []
    for S1, mean_1 in zip(S1s, mean_1s):
        T1_list.append(
            torch.tensor([[S1, 0, -S1 * mean_1[0, 0]],
                          [0, S1, -S1 * mean_1[0, 1]], [0, 0, 1]],
                         device=X.device))
    T1s = torch.stack(T1_list)
    X_normalized = utils_misc._de_homo(
        torch.bmm(T1s, X.transpose(1, 2)).transpose(
            1, 2))  # ideally zero mean (x, y), and sqrt(2) average norm


    Y = utils_misc._homo(Y)
    mean_2s = torch.mean(Y[:, :, :2], dim=1, keepdim=True)
    S2s = np.sqrt(2) / torch.mean(torch.norm(Y[:, :, :2] - mean_2s, 2, dim=2),
                                  dim=1)
    T2_list = []
    for S2, mean_2 in zip(S2s, mean_2s):
        T2_list.append(
            torch.tensor([[S2, 0, -S2 * mean_2[0, 0]],
                          [0, S2, -S2 * mean_2[0, 1]], [0, 0, 1]],
                         device=X.device))
    T2s = torch.stack(T2_list)
    Y_normalized = utils_misc._de_homo(
        torch.bmm(T2s, Y.transpose(1, 2)).transpose(1, 2))

    return X_normalized, Y_normalized, T1s, T2s
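# Batched version of the same sanity check (a sketch, not original code): each
# sample in the batch should independently end up with ~zero mean and sqrt(2)
# average norm.
def _demo_normalize_XY_batch():
    torch.manual_seed(0)
    X = torch.rand(4, 50, 2) * 100.
    Y = torch.rand(4, 50, 2) * 100.
    Xn, Yn, T1s, T2s = _normalize_XY_batch(X, Y)
    print(Xn.shape, T1s.shape)                          # [4, 50, 2], [4, 3, 3]
    print(torch.mean(torch.norm(Xn, 2, dim=2), dim=1))  # ~1.4142 per sample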
def _epi_distance(F, X, Y, if_homo=False):
    # Unsquared point-to-epipolar-line distances (averaged both ways); see https://arxiv.org/pdf/1706.07886.pdf
    if not if_homo:
        X = utils_misc._homo(X)
        Y = utils_misc._homo(Y)
    if len(X.size()) == 2:
        numerator = torch.diag(Y @ F @ X.t()).abs()
        Fx1 = torch.mm(F, X.t())
        Fx2 = torch.mm(F.t(), Y.t())
        denom_recp_Y_to_FX = 1. / torch.sqrt(Fx1[0]**2 + Fx1[1]**2)
        denom_recp_X_to_FY = 1. / torch.sqrt(Fx2[0]**2 + Fx2[1]**2)
    else:
        numerator = (torch.diagonal(Y @ F @ X.transpose(1, 2), dim1=1,
                                    dim2=2)).abs()
        Fx1 = torch.matmul(F, X.transpose(1, 2))
        Fx2 = torch.matmul(F.transpose(1, 2), Y.transpose(1, 2))
        denom_recp_Y_to_FX = 1. / torch.sqrt(Fx1[:, 0]**2 + Fx1[:, 1]**2)
        denom_recp_X_to_FY = 1. / torch.sqrt(Fx2[:, 0]**2 + Fx2[:, 1]**2)
    dist1 = numerator * denom_recp_Y_to_FX
    dist2 = numerator * denom_recp_X_to_FY
    return (dist1 + dist2) / 2., dist1, dist2
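# Usage sketch (not original code): _epi_distance returns the mean of the two
# one-sided point-to-epipolar-line distances plus each one-sided distance, all
# unsquared, so values are directly comparable to pixel units.
def _demo_epi_distance():
    torch.manual_seed(0)
    X = torch.rand(10, 2)
    Y = torch.rand(10, 2)
    U, _, V = torch.svd(torch.rand(3, 3))
    F = U @ torch.diag(torch.tensor([1., 0.5, 0.])) @ V.t()  # rank-2 F
    dist_mean, dist1, dist2 = _epi_distance(F, X, Y)
    print(dist_mean.shape, dist1.shape, dist2.shape)  # all torch.Size([10])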
def _E_from_XY_batch(
    X,
    Y,
    K,
    W=None,
    if_normzliedK=False,
    normalize=True,
    show_debug=False
):  # Ref: https://github.com/marktao99/python/blob/master/CVP/samples/sfm.py#L55
    """ Normalized Eight Point Algorithom for E: [Manmohan] In practice, one would transform the data points by K^{-1}, then do a Hartley normalization, then estimate the F matrix (which is now E matrix), then set the singular value conditions, then denormalize. Note that it's better to set singular values first, then denormalize.
        X, Y: [N, 2] """
    assert X.dtype == torch.float32, 'batch_svd currently only supports torch.float32!'
    if if_normzliedK:
        X_normalizedK = X.float()
        Y_normalizedK = Y.float()
    else:
        X_normalizedK = utils_misc._de_homo(
            torch.bmm(torch.inverse(K),
                      utils_misc._homo(X).transpose(1,
                                                    2)).transpose(1,
                                                                  2)).float()
        Y_normalizedK = utils_misc._de_homo(
            torch.bmm(torch.inverse(K),
                      utils_misc._homo(Y).transpose(1,
                                                    2)).transpose(1,
                                                                  2)).float()

    if normalize:
        X, Y, T1, T2 = _normalize_XY_batch(X_normalizedK, Y_normalizedK)
    else:
        X, Y = X_normalizedK, Y_normalizedK


    xx = torch.cat([X, Y], dim=2)
    XX = torch.stack([
        xx[:, :, 2] * xx[:, :, 0], xx[:, :, 2] * xx[:, :, 1], xx[:, :, 2],
        xx[:, :, 3] * xx[:, :, 0], xx[:, :, 3] * xx[:, :, 1], xx[:, :, 3],
        xx[:, :, 0], xx[:, :, 1],
        torch.ones_like(xx[:, :, 0])
    ],
                     dim=2)

    if W is not None:
        XX = torch.bmm(W, XX)
    # SVD per sample (a batched SVD such as batch_svd would also work here).
    V_list = []
    D_first = None
    for i, XX_single in enumerate(XX):
        _, D_single, V_single = torch.svd(XX_single, some=True)
        if i == 0:
            D_first = D_single
        V_list.append(V_single[:, -1])
    V_last_col = torch.stack(V_list)

    if show_debug:
        print('[info.Debug @_E_from_XY_batch] Singular values of XX:\n',
              D_first.numpy())

    F_recover = V_last_col.view(-1, 3, 3)

    FU, FD, FV = batch_svd(F_recover)

    if show_debug:
        print('[info.Debug @_E_from_XY_batch] Singular values for recovered E(F):\n',
              FD[0].numpy())

    S_110 = torch.diag(
        torch.tensor([1., 1., 0.], dtype=FU.dtype,
                     device=FU.device)).unsqueeze(0).expand(
                         FV.size()[0], -1, -1)

    E_recover_110 = torch.bmm(FU, torch.bmm(S_110, FV.transpose(1, 2)))
    if normalize:
        E_recover_110 = torch.bmm(T2.transpose(1, 2),
                                  torch.bmm(E_recover_110, T1))
    return -E_recover_110  # E is defined only up to scale, so the global sign is arbitrary
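# Batched usage sketch (assumes batch_svd, e.g. from torch_batch_svd, is
# importable, as the function above requires); an identity K treats the inputs
# as already calibrated. Only demonstrates shapes and the float32 requirement.
def _demo_E_from_XY_batch():
    torch.manual_seed(0)
    X = torch.rand(4, 16, 2).float()  # [B, N, 2], N >= 8
    Y = torch.rand(4, 16, 2).float()
    K = torch.eye(3).unsqueeze(0).expand(4, -1, -1).contiguous().float()
    E = _E_from_XY_batch(X, Y, K)
    print(E.shape)  # torch.Size([4, 3, 3])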
def _E_from_XY(
    X,
    Y,
    K,
    W=None,
    if_normzliedK=False,
    normalize=True,
    show_debug=False
):  # Ref: https://github.com/marktao99/python/blob/master/CVP/samples/sfm.py#L55
    """ Normalized Eight Point Algorithom for E: [Manmohan] In practice, one would transform the data points by K^{-1}, then do a Hartley normalization, then estimate the F matrix (which is now E matrix), then set the singular value conditions, then denormalize. Note that it's better to set singular values first, then denormalize.
        X, Y: [N, 2] """
    if if_normzliedK:
        X_normalizedK = X
        Y_normalizedK = Y
    else:
        X_normalizedK = utils_misc._de_homo(
            torch.mm(torch.inverse(K),
                     utils_misc._homo(X).t()).t())
        Y_normalizedK = utils_misc._de_homo(
            torch.mm(torch.inverse(K),
                     utils_misc._homo(Y).t()).t())

    if normalize:
        X, Y, T1, T2 = _normalize_XY(X_normalizedK, Y_normalizedK)
    else:
        X, Y = X_normalizedK, Y_normalizedK

    xx = torch.cat([X.t(), Y.t()], dim=0)
    XX = torch.stack([
        xx[2, :] * xx[0, :], xx[2, :] * xx[1, :], xx[2, :],
        xx[3, :] * xx[0, :], xx[3, :] * xx[1, :], xx[3, :], xx[0, :], xx[1, :],
        torch.ones_like(xx[0, :])
    ],
                     dim=0).t()  # [N, 9]
    if W is not None:
        XX = torch.mm(W, XX)  # [N, 9]
    U, D, V = torch.svd(XX, some=True)
    if show_debug:
        print('[info.Debug @_E_from_XY] Singular values of XX:\n', D.numpy())


    F_recover = torch.reshape(V[:, -1], (3, 3))

    FU, FD, FV = torch.svd(F_recover, some=True)
    if show_debug:
        print('[info.Debug @_E_from_XY] Singular values for recovered E(F):\n',
              FD.numpy())

    S_110 = torch.diag(
        torch.tensor([1., 1., 0.], dtype=FU.dtype, device=FU.device))
    E_recover_110 = torch.mm(FU, torch.mm(S_110, FV.t()))

    if normalize:
        E_recover_110 = torch.mm(T2.t(), torch.mm(E_recover_110, T1))
    return E_recover_110
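# A minimal single-pair usage sketch (identity K treats the points as already
# calibrated; W is an optional per-correspondence weight matrix). This only
# demonstrates the call, not estimation accuracy.
def _demo_E_from_XY():
    torch.manual_seed(0)
    X = torch.rand(20, 2)
    Y = torch.rand(20, 2)
    K = torch.eye(3)
    W = torch.diag(torch.ones(20))  # identity weights, i.e. unweighted
    E = _E_from_XY(X, Y, K, W=W)
    print(E.shape)          # torch.Size([3, 3])
    print(torch.svd(E)[1])  # third singular value ~ 0: the rank-2 constraint survives denormalization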
    def eval_one_sample(self, sample):
        import torch
        # If these imports fail: export KITTI_UTILS_PATH='/home/ruizhu/Documents/Projects/kitti_instance_RGBD_utils'
        import dsac_tools.utils_F as utils_F
        import dsac_tools.utils_opencv as utils_opencv
        import dsac_tools.utils_vis as utils_vis
        import dsac_tools.utils_misc as utils_misc
        import dsac_tools.utils_geo as utils_geo
        from train_good_utils import val_rt, get_matches_from_SP

        # params
        config = self.config
        net_dict = self.net_dict
        if_SP = self.config["model"]["if_SP"]
        if_quality = self.config["model"]["if_quality"]
        device = self.device
        net_SP_helper = self.net_SP_helper

        task = "validating"
        imgs = sample["imgs"]  # [batch_size, H, W, 3]
        Ks = sample["K"].to(device)  # [batch_size, 3, 3]
        K_invs = sample["K_inv"].to(device)  # [batch_size, 3, 3]
        batch_size = Ks.size(0)
        scene_names = sample["scene_name"]
        frame_ids = sample["frame_ids"]
        scene_poses = sample["relative_scene_poses"]
        # list of sequence_length tensors, each of size [batch_size, 4, 4]; the first is identity, the rest are [[R, t], [0, 1]]
        if config["data"]["read_what"]["with_X"]:
            Xs = sample[
                "X_cam2s"]  # list of [batch_size, 3, Ni]; only support batch_size=1 because of variable points Ni for each sample
        # sift_kps, sift_deses = sample['sift_kps'], sample['sift_deses']
        assert sample["get_flags"]["have_matches"][0].numpy(
        ), "Did not find the corres files!"
        matches_all, matches_good = sample["matches_all"], sample[
            "matches_good"]
        quality_all, quality_good = sample["quality_all"], sample[
            "quality_good"]

        delta_Rtijs_4_4 = scene_poses[1].float()  # [batch_size, 4, 4], assuming 2 frames so that scene_poses[0] is identity
        E_gts, F_gts = sample["E"], sample["F"]
        pts1_virt_normalizedK, pts2_virt_normalizedK = (
            sample["pts1_virt_normalized"].to(device),
            sample["pts2_virt_normalized"].to(device),
        )
        pts1_virt_ori, pts2_virt_ori = (
            sample["pts1_virt"].to(device),
            sample["pts2_virt"].to(device),
        )
        # pts1_virt_ori, pts2_virt_ori = sample['pts1_velo'].to(device), sample['pts2_velo'].to(device)

        # Get and Normalize points
        if if_SP:
            net_SP = net_dict["net_SP"]
            SP_processer, SP_tracker = (
                net_SP_helper["SP_processer"],
                net_SP_helper["SP_tracker"],
            )
            xs, offsets, quality = get_matches_from_SP(sample["imgs_grey"],
                                                       net_SP, SP_processer,
                                                       SP_tracker)
            matches_use = xs + offsets
            quality_use = quality
        else:
            # Get and Normalize points
            matches_use = matches_good  # [SWITCH!!!]
            quality_use = quality_good.to(device) if if_quality else None  # [SWITCH!!!]

        ## process x1, x2
        matches_use = matches_use.to(device)

        N_corres = matches_use.shape[1]  # 1311 for matches_good, 2000 for matches_all
        x1, x2 = (
            matches_use[:, :, :2],
            matches_use[:, :, 2:],
        )  # [batch_size, N, 2(W, H)]
        x1_normalizedK = utils_misc._de_homo(
            torch.matmul(torch.inverse(Ks),
                         utils_misc._homo(x1).transpose(1, 2)).transpose(
                             1, 2))  # [batch_size, N, 2(W, H)], min/max_X=[-W/2/f, W/2/f]
        x2_normalizedK = utils_misc._de_homo(
            torch.matmul(torch.inverse(Ks),
                         utils_misc._homo(x2).transpose(1, 2)).transpose(
                             1, 2))  # [batch_size, N, 2(W, H)], min/max_X=[-W/2/f, W/2/f]
        matches_use_normalizedK = torch.cat((x1_normalizedK, x2_normalizedK),
                                            2)

        matches_use_ori = torch.cat((x1, x2), 2)

        # Get image feats
        if config["model"]["if_img_feat"]:
            imgs = sample["imgs"]  # [batch_size, H, W, 3]
            imgs_stack = ((torch.cat(imgs, 3).float() - 127.5) /
                          127.5).permute(0, 3, 1, 2)

        qs_scene = sample["q_scene"].to(device)  # [B, 4, 1]
        ts_scene = sample["t_scene"].to(device)  # [B, 3, 1]
        qs_cam = sample["q_cam"].to(device)  # [B, 4, 1]
        ts_cam = sample["t_cam"].to(device)  # [B, 3, 1]

        t_scene_scale = torch.norm(ts_scene, p=2, dim=1, keepdim=True)


        data_batch = {
            "matches_xy_ori": matches_use_ori,
            "quality": quality_use,
            "x1_normalizedK": x1_normalizedK,
            "x2_normalizedK": x2_normalizedK,
            "Ks": Ks,
            "K_invs": K_invs,
            "matches_good_unique_nums": sample["matches_good_unique_nums"],
            "t_scene_scale": t_scene_scale,
        }
        loss_params = {
            "model": config["model"]["name"],
            "clamp_at": config["model"]["clamp_at"],
            "depth": config["model"]["depth"],
        }

        with torch.no_grad():
            outs = net_dict["net_deepF"](data_batch)

            pts1_eval, pts2_eval = pts1_virt_ori, pts2_virt_ori

            logits_weights = outs["weights"]
            loss_E = 0.0

            F_out, T1, T2, out_a = (
                outs["F_est"],
                outs["T1"],
                outs["T2"],
                outs["out_layers"],
            )
            pts1_eval = torch.bmm(T1,
                                  pts1_virt_ori.permute(0, 2,
                                                        1)).permute(0, 2, 1)
            pts2_eval = torch.bmm(T2,
                                  pts2_virt_ori.permute(0, 2,
                                                        1)).permute(0, 2, 1)


            loss_layers = []
            losses_layers = []
            out_a.append(F_out)
            loss_all = 0.0
            for d in range(loss_params["depth"]):
                losses = utils_F.compute_epi_residual(pts1_eval, pts2_eval,
                                                      out_a[d],
                                                      loss_params["clamp_at"])
                # losses = utils_F._YFX(pts1_eval, pts2_eval, out_a[iter], if_homo=True, clamp_at=loss_params['clamp_at'])
                losses_layers.append(losses)
                loss = losses.mean()
                loss_layers.append(loss)
                loss_all += loss

            loss_all = loss_all / len(loss_layers)

            F_ests = T2.permute(0, 2, 1).bmm(F_out.bmm(T1))
            E_ests = Ks.transpose(1, 2) @ F_ests @ Ks

            last_losses = losses_layers[-1].detach().cpu().numpy()
            print(last_losses)
            print(np.amax(last_losses, axis=1))

        # E_ests_list = []
        # for x1_single, x2_single, K, w in zip(x1, x2, Ks, logits_weights):
        #     E_est = utils_F._E_from_XY(x1_single, x2_single, K, torch.diag(w))
        #     E_ests_list.append(E_est)
        # E_ests = torch.stack(E_ests_list).to(device)
        # F_ests = utils_F._E_to_F(E_ests, Ks)
        K_np = Ks.cpu().detach().numpy()
        x1_np, x2_np = x1.cpu().detach().numpy(), x2.cpu().detach().numpy()
        E_est_np = E_ests.cpu().detach().numpy()
        F_est_np = F_ests.cpu().detach().numpy()
        delta_Rtijs_4_4_cpu_np = delta_Rtijs_4_4.cpu().numpy()

        # Tests and vis
        idx = 0
        img1 = imgs[0][idx].numpy().astype(np.uint8)
        img2 = imgs[1][idx].numpy().astype(np.uint8)
        img1_rgb, img2_rgb = img1, img2
        img1_rgb_np, img2_rgb_np = img1, img2
        im_shape = img1.shape
        x1 = x1_np[idx]
        x2 = x2_np[idx]
        #         utils_vis.draw_corr(img1, img2, x1, x2)

        delta_Rtij = delta_Rtijs_4_4_cpu_np[idx]
        print("----- delta_Rtij", delta_Rtij)
        delta_Rtij_inv = np.linalg.inv(delta_Rtij)
        K = K_np[idx]
        F_gt_th = F_gts[idx].cpu()
        F_gt = F_gt_th.numpy()
        E_gt_th = E_gts[idx].cpu()
        E_gt = E_gt_th.numpy()
        F_est = F_est_np[idx]
        E_est = E_est_np[idx]

        unique_rows_all, unique_rows_all_idxes = np.unique(np.hstack((x1, x2)),
                                                           axis=0,
                                                           return_index=True)
        mask_sample = np.random.choice(x1.shape[0], 100)
        angle_R = utils_geo.rot12_to_angle_error(np.eye(3),
                                                 delta_Rtij_inv[:3, :3])
        angle_t = utils_geo.vector_angle(np.array([[0.0], [0.0], [1.0]]),
                                         delta_Rtij_inv[:3, 3:4])
        print(
            ">>>>>>>>>>>>>>>> Between frames: The rotation angle (degree) %.4f, and translation angle (degree) %.4f"
            % (angle_R, angle_t))
        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[mask_sample],
            x2[mask_sample],
            linewidth=2.0,
            title="Sample of 100 corres.",
        )

        #         ## Baseline: 8-points
        #         M_8point, error_Rt_8point, mask2_8point, E_est_8point = utils_opencv.recover_camera_opencv(K, x1, x2, delta_Rtij_inv, five_point=False, threshold=0.01)

        ## Baseline: OpenCV 5-point or 8-point (toggled by five_point; False runs 8-point)
        five_point = False
        M_opencv, error_Rt_opencv, mask2, E_return = utils_opencv.recover_camera_opencv(
            K, x1, x2, delta_Rtij_inv, five_point=five_point, threshold=0.01)

        if five_point:
            E_est_opencv = E_return
            F_est_opencv = utils_F.E_to_F_np(E_est_opencv, K)
        else:
            E_est_opencv, F_est_opencv = E_return[0], E_return[1]

        ## Check geo dists
        print(f"K: {K}")
        x1_normalizedK = utils_misc.de_homo_np(
            (np.linalg.inv(K) @ utils_misc.homo_np(x1).T).T)
        x2_normalizedK = utils_misc.de_homo_np(
            (np.linalg.inv(K) @ utils_misc.homo_np(x2).T).T)
        K_th = torch.from_numpy(K)
        F_gt_normalized = K_th.t() @ F_gt_th @ K_th  # should be identical to E_gts[idx]

        geo_dists = utils_F._sym_epi_dist(
            F_gt_normalized,
            torch.from_numpy(x1_normalizedK),
            torch.from_numpy(x2_normalizedK),
        ).numpy()
        geo_thres = 1e-4
        mask_in = geo_dists < geo_thres
        mask_out = geo_dists >= geo_thres

        mask_sample = mask2
        print(mask2.shape)
        np.set_printoptions(precision=8, suppress=True)

        ## Ours: Some analysis
        print("----- Oursssssssssss")
        scores_ori = logits_weights.cpu().numpy().flatten()
        import matplotlib.pyplot as plt

        plt.hist(scores_ori, 100)
        plt.show()
        sort_idxes = np.argsort(scores_ori[unique_rows_all_idxes])[::-1]
        scores = scores_ori[unique_rows_all_idxes][sort_idxes]
        num_corr = 100
        mask_conf = sort_idxes[:num_corr]
        # mask_sample = np.array(range(x1.shape[0]))[mask_sample][:20]

        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[unique_rows_all_idxes],
            x2[unique_rows_all_idxes],
            linewidth=2.0,
            title=f"All {unique_rows_all_idxes.shape[0]} correspondences",
        )

        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[unique_rows_all_idxes][mask_conf, :],
            x2[unique_rows_all_idxes][mask_conf, :],
            linewidth=2.0,
            title=f"Ours top {num_corr} confidents",
        )
        #         print('(%d unique corres)'%scores.shape[0])
        utils_vis.show_epipolar_rui_gtEst(
            x2[unique_rows_all_idxes][mask_conf, :],
            x1[unique_rows_all_idxes][mask_conf, :],
            img2_rgb,
            img1_rgb,
            F_gt.T,
            F_est.T,
            weights=scores_ori[unique_rows_all_idxes][mask_conf],
            im_shape=im_shape,
            title_append="Ours top %d with largest score points" %
            mask_conf.shape[0],
        )
        print(f"F_gt: {F_gt/F_gt[2, 2]}")
        print(f"F_est: {F_est/F_est[2, 2]}")
        error_Rt_est_ours, epi_dist_mean_est_ours, _, _, _, _, _, M_estW = val_rt(
            idx,
            K,
            x1,
            x2,
            E_est,
            E_gt,
            F_est,
            F_gt,
            delta_Rtij,
            five_point=False,
            if_opencv=False,
        )
        print(
            "Recovered by ours (camera): The rotation error (degree) %.4f, and translation error (degree) %.4f"
            % (error_Rt_est_ours[0], error_Rt_est_ours[1]))
        #         print(epi_dist_mean_est_ours, np.mean(epi_dist_mean_est_ours))
        print("%.2f, %.2f" % (
            np.sum(epi_dist_mean_est_ours < 0.1) /
            epi_dist_mean_est_ours.shape[0],
            np.sum(epi_dist_mean_est_ours < 1) /
            epi_dist_mean_est_ours.shape[0],
        ))

        ## OpenCV: Some analysis
        corres = np.hstack((x1[mask_sample, :], x2[mask_sample, :]))

        unique_rows = np.unique(corres,
                                axis=0) if corres.shape[0] > 0 else corres

        opencv_name = "5-point" if five_point else "8-point"
        utils_vis.draw_corr(
            img1_rgb,
            img2_rgb,
            x1[mask_sample, :],
            x2[mask_sample, :],
            linewidth=2.0,
            title=f"OpenCV {opencv_name} inliers",
        )

        print("----- OpenCV %s (%d unique inliers)" %
              (opencv_name, unique_rows.shape[0]))
        utils_vis.show_epipolar_rui_gtEst(
            x2[mask_sample, :],
            x1[mask_sample, :],
            img2_rgb,
            img1_rgb,
            F_gt.T,
            F_est_opencv.T,
            weights=scores_ori[mask_sample],
            im_shape=im_shape,
            title_append="OpenCV 5-point with its inliers",
        )
        print(F_gt / F_gt[2, 2])
        print(F_est_opencv / F_est_opencv[2, 2])
        error_Rt_est_5p, epi_dist_mean_est_5p, _, _, _, _, _, M_estOpenCV = val_rt(
            idx,
            K,
            x1,
            x2,
            E_est_opencv,
            E_gt,
            F_est_opencv,
            F_gt,
            delta_Rtij,
            five_point=False,
            if_opencv=False,
        )
        print(
            "Recovered by OpenCV %s (camera): The rotation error (degree) %.4f, and translation error (degree) %.4f"
            % (opencv_name, error_Rt_est_5p[0], error_Rt_est_5p[1]))
        print("%.2f, %.2f" % (
            np.sum(epi_dist_mean_est_5p < 0.1) / epi_dist_mean_est_5p.shape[0],
            np.sum(epi_dist_mean_est_5p < 1) / epi_dist_mean_est_5p.shape[0],
        ))

        print("+++ GT, Opencv_5p, Ours")
        np.set_printoptions(precision=4, suppress=True)
        print(delta_Rtij_inv[:3])
        print(
            np.hstack((
                M_opencv[:, :3],
                M_opencv[:, 3:4] / M_opencv[2, 3] * delta_Rtij_inv[2, 3],
            )))
        print(
            np.hstack((M_estW[:, :3],
                       M_estW[:, 3:4] / M_estW[2, 3] * delta_Rtij_inv[2, 3])))

        return {
            "img1_rgb": img1_rgb,
            "img2_rgb": img2_rgb,
            "delta_Rtij": delta_Rtij
        }