Ejemplo n.º 1
0
def main():
    """Evaluate SIFT/COLMAP-based localization on CambridgeLandmarks or 7Scenes.

    For every OpenCV RANSAC threshold in ``--cv_ransac_thres`` an essential
    matrix is estimated per image pair from the COLMAP database matches, the
    structured predictions are saved to ``preds.rthres<thres>.npy`` and the
    essential-matrix-based localization RANSAC is run over every threshold in
    ``--loc_ransac_thres``. Results are appended to
    ``<output_root>/<dataset>/<log_txt>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_root', '-droot', type=str, default='data/datasets_original/')
    parser.add_argument('--dataset', '-ds', type=str, default='CambridgeLandmarks',
                        choices=['CambridgeLandmarks', '7Scenes'])
    parser.add_argument('--scenes', '-sc', type=str, nargs='*', default=None)
    parser.add_argument('--colmap_db_root', '-dbdir', type=str, default='data/colmap_dbs')
    parser.add_argument('--db_name', '-db', type=str, default='database.db')
    parser.add_argument('--pair_txt', '-pair', type=str, default='test_pairs.5nn.300cm50m.vlad.minmax.txt')
    parser.add_argument('--train_lbl_txt', type=str, default='dataset_train.txt')
    parser.add_argument('--test_lbl_txt', type=str, default='dataset_test.txt')
    parser.add_argument('--gpu', '-gpu', type=int, default=0)
    parser.add_argument('--cv_ransac_thres', type=float, nargs='*', default=[0.5])
    parser.add_argument('--loc_ransac_thres', type=float, nargs='*', default=[5])
    parser.add_argument('--output_root', '-odir', type=str, default='output/sift/')
    parser.add_argument('--log_txt', '-log', type=str, default='test_results.txt')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    scene_dict = {
        'CambridgeLandmarks': ['KingsCollege', 'OldHospital', 'ShopFacade', 'StMarysChurch'],
        '7Scenes': ['chess', 'fire', 'heads', 'office', 'pumpkin', 'redkitchen', 'stairs'],
    }
    dataset = args.dataset
    # No (or an empty) --scenes list means: evaluate every scene of the dataset
    scenes = args.scenes if args.scenes else scene_dict[dataset]

    out_dir = os.path.join(args.output_root, dataset)
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs
    os.makedirs(out_dir, exist_ok=True)

    log_path = os.path.join(out_dir, args.log_txt)
    # The context manager closes the log even if an evaluation step raises
    with open(log_path, 'a') as log:
        lprint('Log results {}'.format(log_path))

        for rthres in args.cv_ransac_thres:
            lprint('>>>>Eval Sift CV Ransac : {} Pair: {}'.format(rthres, args.pair_txt), log)
            ransac_sav_path = os.path.join(out_dir, 'ransacs.rthres{}.npy'.format(rthres))

            # Predict essential matrices in a structured data dict
            result_dict = predict_essential_matrix(args.data_root, dataset, scenes,
                                                   args.pair_txt, args.train_lbl_txt, args.test_lbl_txt,
                                                   args.colmap_db_root, db_name=args.db_name,
                                                   rthres=rthres, log=log)
            np.save(os.path.join(out_dir, 'preds.rthres{}.npy'.format(rthres)), result_dict)

            # Global localization via essential-matrix-based RANSAC
            eval_pipeline_with_ransac(result_dict, log, ransac_thres=args.loc_ransac_thres,
                                      ransac_iter=10, ransac_miu=1.414, pair_type='ess',
                                      err_thres=[(0.25, 2), (0.5, 5), (5, 10)],
                                      save_res_path=ransac_sav_path)
Ejemplo n.º 2
0
def test(net,
         config,
         log,
         test_loaders,
         sav_res_name=None,
         err_thres=None):
    """Evaluate ``net`` on the test loaders and run RANSAC-based localization.

    Args:
        net: network to evaluate; switched to eval mode and passed to
            ``eval_prediction``.
        config: configuration object; reads ``pair_txt``, ``odir``,
            ``pair_type`` and ``ransac_thres``.
        log: log destination passed to ``lprint`` and the evaluation helpers.
        test_loaders: dict mapping dataset/scene name -> data loader.
        sav_res_name: if given, predictions are saved to
            ``<config.odir>/predictions.npy`` and localization results to
            ``<config.odir>/<sav_res_name>``.
        err_thres: list of (distance, angle) error thresholds used for pass
            rates. Defaults to ``[(0.25, 2), (0.5, 5), (5, 10)]``.

    Returns:
        The value returned by ``eval_pipeline_with_ransac``.
    """
    if err_thres is None:
        # Built per call instead of as a shared mutable default argument
        err_thres = [(0.25, 2), (0.5, 5), (5, 10)]

    lprint('Testing Pairs: {} Err thres: {}'.format(config.pair_txt,
                                                    err_thres))
    save_res_path = (os.path.join(config.odir, sav_res_name)
                     if sav_res_name is not None else None)
    print('Evaluate on datasets: {}'.format(test_loaders.keys()))

    net.eval()  # Switch to evaluation mode
    with torch.no_grad():
        t1 = time.time()
        result_dict = eval_prediction(test_loaders,
                                      net,
                                      log,
                                      pair_type=config.pair_type)
        if sav_res_name is not None:
            np.save(os.path.join(config.odir, 'predictions.npy'), result_dict)
        t2 = time.time()
        print('Total prediction time: {:0.3f}s'.format(t2 - t1))
        abs_err = eval_pipeline_with_ransac(result_dict,
                                            log,
                                            ransac_thres=config.ransac_thres,
                                            ransac_iter=10,
                                            ransac_miu=1.414,
                                            pair_type=config.pair_type,
                                            err_thres=err_thres,
                                            save_res_path=save_res_path)
        # Read the clock once so both reported durations share one end point
        t3 = time.time()
        print('Total ransac time: {:0.3f}s, Total time: {:0.3f}s'.format(
            t3 - t2, t3 - t1))
    return abs_err
Ejemplo n.º 3
0
def train(net, config, log, train_loader, val_loaders=None):
    """Train ``net`` for epoch indices ``config.start_epoch`` .. ``config.epochs - 1``.

    After every epoch a rolling checkpoint ``checkpoint.current.pth`` is
    written to ``config.ckpt_dir``. The loss is logged on every second epoch.
    When ``config.validate`` is truthy, the model is evaluated with ``test``
    on ``val_loaders`` every ``config.validate`` epochs (never at epoch index
    0), and an additional checkpoint named after the returned errors is
    saved. Loss and validation errors are pushed to the visdom meters.
    """
    # Visdom monitor providing the loss / position / rotation meters
    monitor = RelaPoseTmp(legend_tag=config.optim_tag,
                          viswin=config.viswin,
                          visenv=config.visenv,
                          vishost=config.vishost,
                          visport=config.visport)
    loss_meter, pos_acc_meter, rot_acc_meter = monitor.get_meters()

    def snapshot(last_epoch):
        # Serializable training state: network weights plus optimizer state
        return {
            'last_epoch': last_epoch,
            'network': config.network,
            'state_dict': net.state_dict(),
            'optimizer': net.optimizer.state_dict()
        }

    print('Start training from {config.start_epoch} to {config.epochs}.'.format(config=config))
    tic = time.time()
    for epoch in range(config.start_epoch, config.epochs):
        step = epoch + 1  # 1-based epoch number used in logs, plots and file names
        net.train()  # Switch to training mode
        loss = net.train_epoch(train_loader, epoch)
        if step % 2 == 0:
            lprint('Epoch {}, loss:{}'.format(step, loss), log)
            loss_meter.update(X=step, Y=loss)

        # Overwritten every epoch so training progress is never lost
        torch.save(snapshot(epoch),
                   os.path.join(config.ckpt_dir, 'checkpoint.current.pth'))

        due = config.validate and step % config.validate == 0 and epoch > 0
        if due:
            abs_err = test(net, config, log, val_loaders)
            ckpt = snapshot(epoch)
            ckpt_name = 'checkpoint_{epoch}_{abs_err[0]:.2f}m_{abs_err[1]:.2f}deg.pth'.format(
                epoch=step, abs_err=abs_err)
            torch.save(ckpt, os.path.join(config.ckpt_dir, ckpt_name))
            lprint('Save checkpoint: {}'.format(ckpt_name), log)

            pos_acc_meter.update(X=step, Y=abs_err[0])
            rot_acc_meter.update(X=step, Y=abs_err[1])
        monitor.save_state()
    lprint('Total training time {0:.4f}s'.format((time.time() - tic)),
           log)
Ejemplo n.º 4
0
def eval_pipeline_without_ransac(result_dict, err_thres=(2, 5), log=None):
    """Log and average pose errors for every dataset, without RANSAC.

    For each dataset in ``result_dict``, relative errors come from
    ``cal_rela_pose_err`` and absolute errors (plus a pass rate) come from
    ``cal_abs_pose_err(pair_data, err_thres)``. Both are logged per dataset.

    Returns:
        Tuple of means over datasets:
        (relative translation angle err, relative rotation angle err,
         absolute position distance err, absolute position angle err,
         absolute rotation angle err).
    """
    # One bucket per metric, in the order of the returned tuple
    buckets = ([], [], [], [], [])

    for name, entry in result_dict.items():
        pair_data = entry['pair_data']
        lprint(
            '>>Testing dataset: {}, testing samples: {}'.format(
                name, len(pair_data)), log)

        rela_t, rela_q = cal_rela_pose_err(pair_data)
        c_dist, c_ang, q_ang, passed = cal_abs_pose_err(pair_data, err_thres)
        for bucket, value in zip(buckets, (rela_t, rela_q, c_dist, c_ang, q_ang)):
            bucket.append(value)

        lprint(
            'rela_err ({:.2f}deg, {:.2f}deg) abs err: ({:.2f}m, {:.2f}deg; {:.2f}deg, pass: {:.2f}%)'
            .format(rela_t, rela_q, c_dist, c_ang, q_ang, passed), log)

    eval_val = tuple(np.mean(bucket) for bucket in buckets)
    lprint(
        '>>avg_rela_err ({eval_val[0]:.2f} deg, {eval_val[1]:.2f} deg) avg_abs_err ({eval_val[2]:.2f} m, {eval_val[3]:.2f}deg; {eval_val[4]:.2f}deg)'
        .format(eval_val=eval_val), log)
    return eval_val
Ejemplo n.º 5
0
def predict_essential_matrix(data_root, dataset, scenes,
                             pair_txt, train_lbl_txt, test_lbl_txt,
                             colmap_db_root, db_name='database.db',
                             rthres=0.5, log=None):
    """Estimate essential matrices from COLMAP SIFT matches for every test pair.

    Args:
        data_root: root directory with one sub-directory per dataset.
        dataset: dataset name (sub-directory of ``data_root`` and of
            ``colmap_db_root``).
        scenes: iterable of scene names to process.
        pair_txt, train_lbl_txt, test_lbl_txt: label/pair file names inside
            each scene directory.
        colmap_db_root: root directory of the per-scene COLMAP databases.
        db_name: database file name inside ``<colmap_db_root>/<dataset>/<scene>``.
        rthres: RANSAC threshold forwarded to ``cv2.findEssentialMat``.
        log: log destination passed to ``lprint``.

    Returns:
        dict: scene -> {'pair_data': {test_im: {'test_pairs': [EssPair, ...],
        'test_abs_pose': AbsPose}}, 'no_pt_pairs': [(train_im, test_im), ...]}.
        Pairs without point correspondences are listed in ``no_pt_pairs``;
        pairs whose ``EssPair`` is invalid are dropped.
    """
    result_dict = {}
    print('Sift scenes: ', scenes)
    # Dataset-level directories do not depend on the scene
    base_dir = os.path.join(data_root, dataset)
    db_base_dir = os.path.join(colmap_db_root, dataset)
    for scene in scenes:
        start_time = time.time()
        scene_dir = os.path.join(base_dir, scene)
        intrinsic_loader = get_camera_intrinsic_loader(base_dir, dataset, scene)

        # Load image pairs, pose and essential matrix labels
        abs_pose = parse_abs_pose_txt(os.path.join(scene_dir, train_lbl_txt))
        abs_pose.update(parse_abs_pose_txt(os.path.join(scene_dir, test_lbl_txt)))
        # {(train_im, test_im): (q, t, ess_mat)}
        im_pairs = parse_matching_pairs(os.path.join(scene_dir, pair_txt))
        pair_names = list(im_pairs.keys())

        # Load keypoints, image ids and pair matches from the COLMAP database
        database_path = os.path.join(db_base_dir, scene, db_name)
        db_loader = COLMAPDataLoader(database_path)
        key_points = db_loader.load_keypoints(key_len=6)
        images = db_loader.load_images(name_based=True)
        pair_ids = [(images[im1][0], images[im2][0]) for im1, im2 in pair_names]
        matches = db_loader.load_pair_matches(pair_ids)

        total_time = time.time() - start_time
        print('Scene {} Data loading finished, time:{:.4f}'.format(scene, total_time))

        # Calculate essential matrices per query image
        total_num = len(im_pairs)
        pair_data = {}
        no_pts_pairs = []
        start_time = time.time()
        for im_pair, pair_id in zip(pair_names, pair_ids):
            train_im, test_im = im_pair
            # Dict to save results in a structured way for later RANSAC
            entry = pair_data.setdefault(test_im, {'test_pairs': []})

            # Wrap pose labels with RelaPose, AbsPose objects
            q_lbl, t_lbl, _ = im_pairs[im_pair]
            ctr, qtr = abs_pose[train_im]
            cte, qte = abs_pose[test_im]
            rela_pose_lbl = RelaPose(q_lbl, t_lbl)
            train_abs_pose = AbsPose(qtr, ctr, init_proj=True)
            entry['test_abs_pose'] = AbsPose(qte, cte)

            # Extract point correspondences
            pts, invalid = extract_pair_pts(pair_id, key_points, matches)
            if invalid:
                no_pts_pairs.append(im_pair)
                # TODO: set the accuracy to very large?
                continue

            # Estimate essential matrix from correspondences and extract relative poses
            K = intrinsic_loader.get_relative_intrinsic_matrix(train_im, test_im)
            E, _ = cv2.findEssentialMat(pts[:, 0:2], pts[:, 2:4], cameraMatrix=K,
                                        method=cv2.FM_RANSAC, threshold=rthres)

            t_est, R0, R1 = decompose_essential_matrix(E)
            test_pair = EssPair(test_im, train_im, train_abs_pose, rela_pose_lbl,
                                t_est, R0, R1)
            if test_pair.is_invalid():
                # Invalid pairs that cause 'inf' due to bad retrieval would corrupt RANSAC
                continue
            entry['test_pairs'].append(test_pair)

        total_time = time.time() - start_time
        # Guard against an empty pair file, which would otherwise divide by zero
        per_pair = total_time / total_num if total_num else 0.0
        lprint('Scene:{} Samples total:{} No correspondence pairs:{}. Time total:{:.4f} per_pair:{:.4f}'.format(
            scene, total_num, len(no_pts_pairs), total_time, per_pair), log)
        result_dict[scene] = {'pair_data': pair_data, 'no_pt_pairs': no_pts_pairs}
    return result_dict
Ejemplo n.º 6
0
def predict_essential_matrix(data_root,
                             dataset,
                             data_loaders,
                             model,
                             k_size=2,
                             do_softmax=True,
                             rthres=4.0,
                             ncn_thres=0.9,
                             log=None):
    """Estimate essential matrices from NCN dense matches for every test pair.

    Args:
        data_root: root directory with one sub-directory per dataset.
        dataset: dataset name (sub-directory of ``data_root``).
        data_loaders: dict mapping scene name -> data loader of image pairs.
            Only element 0 of each batch field is used, so a batch size of 1
            is assumed.
        model: matching network exposing ``device`` and ``forward_corr4d``.
        k_size: forwarded to ``cal_matches``; when ``k_size <= 1`` the
            ``delta4d`` returned by ``forward_corr4d`` is not used.
        do_softmax: forwarded to ``cal_matches``.
        rthres: RANSAC threshold forwarded to ``cv2.findEssentialMat``.
        ncn_thres: matches with score <= this value are discarded.
        log: log destination passed to ``lprint``.

    Returns:
        dict: scene -> {'pair_data': {test_im: {'test_pairs': [EssPair, ...],
        'cv_inliers': [inlier_ratio, ...], 'test_abs_pose': AbsPose}}}.
        Pairs with fewer than 5 confident matches, or whose essential matrix
        could not be estimated, record an inlier ratio of 0.0 and add no
        ``EssPair``.
    """
    # findEssentialMat is a 5-point solver; fewer points cannot be used
    min_matches = 5
    base_dir = os.path.join(data_root, dataset)
    result_dict = {}
    for scene, data_loader in data_loaders.items():
        total_num = len(data_loader.dataset)  # Total image pair number
        intrinsic_loader = get_camera_intrinsic_loader(base_dir, dataset,
                                                       scene)
        w, h = intrinsic_loader.w, intrinsic_loader.h

        # Predict essential matrix over samples
        pair_data = {}
        start_time = time.time()
        for batch in data_loader:
            train_im_ref = batch['im_pair_refs'][0][0]
            test_im_ref = batch['im_pair_refs'][1][0]
            # Dict to save results in a structured way for later RANSAC
            entry = pair_data.setdefault(test_im_ref,
                                         {'test_pairs': [], 'cv_inliers': []})

            # Wrap pose labels with RelaPose, AbsPose objects. The test pose is
            # stored before any early exit so every query keeps its label.
            rela_pose_lbl = RelaPose(batch['relv_q'][0].data.numpy(),
                                     batch['relv_t'][0].data.numpy())
            train_abs_pose = AbsPose(batch['train_abs_q'][0].data.numpy(),
                                     batch['train_abs_c'][0].data.numpy(),
                                     init_proj=True)
            entry['test_abs_pose'] = AbsPose(batch['test_abs_q'][0].data.numpy(),
                                             batch['test_abs_c'][0].data.numpy())

            # Calculate correspondence score map
            train_im, test_im = batch['im_pairs']
            with torch.no_grad():
                corr4d, delta4d = model.forward_corr4d(
                    train_im.to(model.device), test_im.to(model.device))
            if k_size <= 1:
                delta4d = None

            # Calculate matches
            xA, yA, xB, yB, score = cal_matches(corr4d,
                                                delta4d,
                                                k_size=k_size,
                                                do_softmax=do_softmax,
                                                matching_both_directions=True)

            # Scale matches to original pixel level. reshape (unlike squeeze)
            # keeps an (N, 4) array even when exactly one match is returned.
            matches = np.dstack([xA * w, yA * h, xB * w,
                                 yB * h]).reshape(-1, 4)

            # Keep confident matches only
            keep = np.where(score > ncn_thres)[0]
            matches = matches[keep, :]
            p1 = matches[:, 0:2]
            p2 = matches[:, 2:4]
            if p1.shape[0] < min_matches:
                entry['cv_inliers'].append(0.0)
                continue

            # Find essential matrix
            K = intrinsic_loader.get_relative_intrinsic_matrix(
                train_im_ref, test_im_ref)
            E, inliers = cv2.findEssentialMat(p1,
                                              p2,
                                              cameraMatrix=K,
                                              method=cv2.FM_RANSAC,
                                              threshold=rthres)
            if E is None or np.size(E) == 0 or inliers is None:
                # Estimation failed; nothing to decompose
                entry['cv_inliers'].append(0.0)
                continue

            # Inlier ratio kept for debugging
            entry['cv_inliers'].append(np.count_nonzero(inliers) / p1.shape[0])

            # Wrap ess pair
            t, R0, R1 = decompose_essential_matrix(E)
            test_pair = EssPair(test_im_ref, train_im_ref, train_abs_pose,
                                rela_pose_lbl, t, R0, R1)
            if test_pair.is_invalid():
                # Invalid pairs that cause 'inf' due to bad retrieval would corrupt RANSAC
                continue
            entry['test_pairs'].append(test_pair)

        total_time = time.time() - start_time
        # Guard against an empty loader, which would otherwise divide by zero
        per_pair = total_time / total_num if total_num else 0.0
        lprint(
            'Scene:{} num_samples:{} total_time:{:.4f} time_per_pair:{:.4f}'.
            format(scene, total_num, total_time, per_pair), log)
        result_dict[scene] = {'pair_data': pair_data}
    return result_dict
Ejemplo n.º 7
0
def main():
    """Evaluate NCMatchNet matches with essential-matrix RANSAC localization.

    For each checkpoint, and for each OpenCV RANSAC threshold in
    ``--cv_ransac_thres``, this function:

    - predicts essential matrices;
    - saves them to ``preds.cvr<thres>_ncn<ncn_thres>.<posfix>.npy``;
    - runs the localization RANSAC over ``--loc_ransac_thres``.

    The checkpoint is none, a single named one, or every ``checkpoint*`` file
    under ``--ckpt_dir``. Results are appended to ``<out_dir>/<posfix>.txt``.
    """
    parser = argparse.ArgumentParser(description='Eval immatch model')
    parser.add_argument('--data_root',
                        '-droot',
                        type=str,
                        default='data/datasets_original/')
    parser.add_argument('--dataset',
                        '-ds',
                        type=str,
                        default='CambridgeLandmarks',
                        choices=['CambridgeLandmarks', '7Scenes'])
    parser.add_argument('--scenes', '-sc', type=str, nargs='*', default=None)
    parser.add_argument('--image_size', '-imsize', type=int, default=None)
    parser.add_argument('--scale_pts', action='store_true')
    parser.add_argument('--pair_txt',
                        '-pair',
                        type=str,
                        default='test_pairs.5nn.300cm50m.vlad.minmax.txt')
    parser.add_argument('--ckpt_dir', '-cdir', type=str, default=None)
    parser.add_argument('--ckpt_name', '-ckpt', type=str, default=None)
    parser.add_argument('--feat', '-feat', type=str, default=None)
    parser.add_argument('--ncn', '-ncn', type=str, default=None)
    parser.add_argument('--gpu', '-gpu', type=int, default=0)
    parser.add_argument('--cv_ransac_thres',
                        type=float,
                        nargs='*',
                        default=[4.0])
    parser.add_argument('--loc_ransac_thres',
                        type=float,
                        nargs='*',
                        default=[15])
    parser.add_argument('--ncn_thres', type=float, default=0.9)
    parser.add_argument('--posfix', type=str, default='imagenet+ncn')
    parser.add_argument('--out_dir',
                        '-o',
                        type=str,
                        default='output/ncmatch_5pt/loc_results')

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    # Data Loading
    data_loaders = get_test_loaders(image_size=args.image_size,
                                    dataset=args.dataset,
                                    scenes=args.scenes,
                                    pair_txt=args.pair_txt,
                                    data_root=args.data_root)

    out_dir = args.out_dir
    os.makedirs(out_dir, exist_ok=True)
    out_path = os.path.join(out_dir, '{}.txt'.format(args.posfix))

    if args.ckpt_dir is None:
        ckpts = [None]
    elif args.ckpt_name:
        # Load the target checkpoint
        ckpts = [os.path.join(args.ckpt_dir, args.ckpt_name)]
    else:
        # Load all checkpoints under ckpt_dir; sorted for a reproducible order
        ckpts = sorted(glob.glob(os.path.join(args.ckpt_dir, 'checkpoint*')))

    # The context manager closes the log even if an evaluation step raises
    with open(out_path, 'a') as log:
        lprint('Log results {} posfix: {}'.format(out_path, args.posfix))
        print('Start evaluation, ckpt to eval: {}'.format(len(ckpts)))
        for ckpt in ckpts:
            # Load models
            lprint(
                '\n\n>>>Eval ImMatchNet:pair_txt: {}\nckpt {} \nfeat {} \nncn {}'.
                format(args.pair_txt, ckpt, args.feat, args.ncn), log)
            config = NCMatchEvalConfig(weights_dir=ckpt,
                                       feat_weights=args.feat,
                                       ncn_weights=args.ncn,
                                       early_feat=True)
            matchnet = NCMatchNet(config)

            for rthres in args.cv_ransac_thres:
                lprint(
                    '\n>>>>cv_ansac : {} ncn_thres: {}'.format(
                        rthres, args.ncn_thres), log)
                result_dict = predict_essential_matrix(args.data_root,
                                                       args.dataset,
                                                       data_loaders,
                                                       matchnet,
                                                       do_softmax=True,
                                                       rthres=rthres,
                                                       ncn_thres=args.ncn_thres,
                                                       log=log)

                np.save(
                    os.path.join(
                        out_dir,
                        'preds.cvr{}_ncn{}.{}.npy'.format(rthres, args.ncn_thres,
                                                          args.posfix)),
                    result_dict)
                eval_pipeline_with_ransac(result_dict,
                                          log,
                                          ransac_thres=args.loc_ransac_thres,
                                          ransac_iter=10,
                                          ransac_miu=1.414,
                                          pair_type='ess',
                                          err_thres=[(0.25, 2), (0.5, 5), (5, 10)],
                                          save_res_path=None)
def eval_prediction(data_loaders, net, log, pair_type='ess'):
    '''Evaluate network predictions and prepare results for RANSAC and accuracy calculation.

    Args:
        - data_loaders: a dict of data loaders (torch.utils.data.DataLoader), one per dataset.
        - net: the network model to be tested; ``net.predict_(batch)`` is called per batch.
        - log: log destination passed through to ``lprint``.
        - pair_type: type of the network prediction. Options: 'ess' (a flattened
          3x3 essential matrix per pair) or 'relapose' (predictions ``(t, q)``).
    Return:
       A dict keyed by dataset name; each value is ``{'pair_data': pair_data}``.
       ``pair_data`` is keyed by test image, and each value is a sub dict with:
       - key 'test_abs_pose': the absolute pose label (AbsPose) of the test image.
       - key 'test_pairs': a list of pair data linked to the test image, each
         an EssPair ('ess') or a RelaPosePair ('relapose').
       Precomputing this structure drastically reduces RANSAC running time.
    Raises:
       ValueError: if ``pair_type`` is neither 'ess' nor 'relapose'.
    '''
    if pair_type not in ('ess', 'relapose'):
        # Without this check an unknown type silently reused a stale test_pair
        raise ValueError(
            "pair_type must be 'ess' or 'relapose', got {!r}".format(pair_type))

    result_dict = {}
    for dataset, data_loader in data_loaders.items():
        start_time = time.time()
        total_num = len(data_loader.dataset)
        pair_data = {}
        for batch in data_loader:
            pred_vec = net.predict_(batch)
            train_refs = batch['im_pair_refs'][0]
            test_refs = batch['im_pair_refs'][1]

            # Calculate poses per query image
            for i, test_im in enumerate(test_refs):
                entry = pair_data.setdefault(test_im, {'test_pairs': []})

                # Wrap pose labels with RelaPose, AbsPose objects
                rela_pose_lbl = RelaPose(batch['relv_q'][i].data.numpy(),
                                         batch['relv_t'][i].data.numpy())
                train_abs_pose = AbsPose(batch['train_abs_q'][i].data.numpy(),
                                         batch['train_abs_c'][i].data.numpy(),
                                         init_proj=True)
                entry['test_abs_pose'] = AbsPose(batch['test_abs_q'][i].data.numpy(),
                                                 batch['test_abs_c'][i].data.numpy())

                # Estimate relative pose
                if pair_type == 'ess':
                    # Prediction is an essential matrix
                    E = pred_vec[i].cpu().data.numpy().reshape((3, 3))
                    t, R0, R1 = decompose_essential_matrix(E)
                    test_pair = EssPair(test_im, train_refs[i], train_abs_pose,
                                        rela_pose_lbl, t, R0, R1)
                else:
                    # Prediction is (t, q); RelaPose takes (q, t)
                    rela_pose_pred = RelaPose(
                        pred_vec[1][i].cpu().data.numpy(),
                        pred_vec[0][i].cpu().data.numpy())
                    test_pair = RelaPosePair(test_im, train_abs_pose,
                                             rela_pose_lbl, rela_pose_pred)
                entry['test_pairs'].append(test_pair)

        total_time = time.time() - start_time
        # Guard against an empty dataset, which would otherwise divide by zero
        per_pair = total_time / total_num if total_num else 0.0
        lprint(
            'Dataset:{} num_samples:{} total_time:{:.4f} time_per_pair:{:.4f}'.
            format(dataset, total_num, total_time, per_pair), log)
        result_dict[dataset] = {'pair_data': pair_data}
    return result_dict
Ejemplo n.º 9
0
def _format_pass_rate(rates):
    """Format a scalar or a (possibly nested) sequence of pass rates as 'a%/b%/...'."""
    # ravel accepts a scalar, a flat sequence or a sequence of 1-element sequences
    return '/'.join('{:.2f}%'.format(v) for v in np.ravel(rates))


def eval_pipeline_with_ransac(result_dict,
                              log,
                              ransac_thres,
                              ransac_iter,
                              ransac_miu,
                              pair_type,
                              err_thres,
                              save_res_path=None):
    """Run ``ransac`` localization on every dataset for each threshold and log errors.

    Args:
        result_dict: dict dataset -> {'pair_data': ...}, as built by the
            prediction functions.
        log: log destination passed to ``lprint``.
        ransac_thres: iterable of RANSAC inlier thresholds to evaluate.
        ransac_iter: forwarded to ``ransac`` as ``in_iter``.
        ransac_miu: only logged; it is not forwarded to ``ransac`` here.
        pair_type: pair type forwarded to ``ransac``; 'angess' is mapped to
            'relapose'.
        err_thres: error thresholds forwarded to ``ransac`` for pass rates.
        save_res_path: if set, per-dataset localization results are saved
            there with ``np.save``.

    Returns:
        ``(mean abs position err, mean abs rotation err)`` averaged over
        datasets, taken from the threshold with the smallest position error.
        Returns None if ``ransac_thres`` is empty.
    """
    lprint(
        '>>>>Evaluate model with Ransac(iter={}, miu={}) Error thres:{})'.
        format(ransac_iter, ransac_miu, err_thres), log)
    # 'angess' results have already been converted to relative poses
    ransac_pair_type = 'relapose' if pair_type == 'angess' else pair_type
    t1 = time.time()
    best_abs_err = None
    for thres in ransac_thres:
        avg_err = []
        avg_pass = []
        lprint('\n>>Ransac threshold:{}'.format(thres), log)
        loc_results_dict = {}
        for dataset in result_dict:
            start_time = time.time()
            pair_data = result_dict[dataset]['pair_data']
            loc_results_dict[dataset] = {} if save_res_path else None
            tested_num, approx_queries, pass_rate, err_res = ransac(
                pair_data,
                thres,
                in_iter=ransac_iter,
                pair_type=ransac_pair_type,
                err_thres=err_thres,
                loc_results=loc_results_dict[dataset])
            avg_err.append(err_res)
            avg_pass.append(pass_rate)
            total_time = time.time() - start_time
            # Dataset names are truncated to 10 characters in the log
            dataset_pr_len = min(10, len(dataset))
            lprint(
                'Dataset:{dataset} Bad/All:{approx_num}/{tested_num}, '
                'Rela:({err_res[0]:.2f}deg,{err_res[1]:.2f}deg) '
                'Abs:({err_res[2]:.2f}m,{err_res[4]:.2f}deg;{err_res[3]:.2f}deg) '
                'Pass:{pass_rate} Time:{time:.2f}s'.format(
                    dataset=dataset[:dataset_pr_len],
                    approx_num=approx_queries,
                    tested_num=tested_num,
                    err_res=err_res,
                    pass_rate=_format_pass_rate(pass_rate),
                    time=total_time), log)

        avg_err = tuple(np.mean(avg_err, axis=0))
        avg_pass = tuple(np.mean(
            avg_pass, axis=0)) if len(err_thres) > 1 else tuple(avg_pass)
        # Keep the threshold with the smallest mean absolute position error
        if best_abs_err is None or avg_err[2] < best_abs_err[0]:
            best_abs_err = (avg_err[2], avg_err[4])
        lprint(
            'Avg: Rela:({err_res[0]:.2f}deg,{err_res[1]:.2f}deg) '
            'Abs:({err_res[2]:.2f}m,{err_res[4]:.2f}deg;{err_res[3]:.2f}deg) '
            'Pass:{pass_rate}'.format(err_res=avg_err,
                                      pass_rate=_format_pass_rate(avg_pass)),
            log)

        if save_res_path:
            # NOTE: the same path is reused for every threshold, so only the
            # last threshold's results remain on disk.
            np.save(save_res_path, loc_results_dict)
    time_stamp = 'Ransac testing time: {}s\n'.format(time.time() - t1)
    lprint(time_stamp, log)
    return best_abs_err
Ejemplo n.º 10
0
def main():
    """Train or test a relative-pose network according to ``RelaPoseConfig``.

    In training mode, the config is logged and the concatenated training
    datasets are built. Validation loaders are built as well when
    ``config.validate`` is set, and then ``train`` runs. Otherwise per-scene
    test loaders are built and ``test`` runs. All logging goes to the file
    ``config.log``, which is opened in append mode.
    """
    # Parse configuration
    config = RelaPoseConfig().parse()

    def make_eval_loader(dset):
        # Deterministic (unshuffled) loader used for validation and testing
        return data.DataLoader(dset,
                               batch_size=config.batch_size,
                               shuffle=False,
                               num_workers=config.num_workers,
                               worker_init_fn=make_deterministic(config.seed))

    # The context manager closes the log even if training/testing raises
    with open(config.log, 'a') as log:
        net = networks.__dict__[config.network](config)

        # Training/ Testing Datasets
        if config.training:
            lprint(config2str(config), log)
            datasets = get_datasets(datasets=config.datasets,
                                    pair_txt=config.pair_txt,
                                    data_root=config.data_root,
                                    incl_sces=config.incl_sces,
                                    ops=config.ops,
                                    train_lbl_txt=config.train_lbl_txt,
                                    test_lbl_txt=config.test_lbl_txt,
                                    with_ess=config.with_ess,
                                    with_virtual_pts=config.with_virtual_pts)
            train_set = data.ConcatDataset(datasets)
            lprint(
                'Concat training datasets total samples: {}'.format(
                    len(train_set)), log)
            train_loader = data.DataLoader(train_set,
                                           batch_size=config.batch_size,
                                           shuffle=True,
                                           num_workers=config.num_workers,
                                           worker_init_fn=make_deterministic(
                                               config.seed))
            val_loaders = None
            if config.validate:
                val_sets = get_datasets(datasets=config.datasets,
                                        pair_txt=config.val_pair_txt,
                                        data_root=config.data_root,
                                        incl_sces=config.incl_sces,
                                        ops=config.val_ops,
                                        train_lbl_txt=config.train_lbl_txt,
                                        test_lbl_txt=config.test_lbl_txt,
                                        with_ess=False)
                val_loaders = {}
                for val_set in val_sets:
                    val_loaders[val_set.scene] = make_eval_loader(val_set)
            train(net, config, log, train_loader, val_loaders)
        else:
            lprint(config2str(config))
            lprint('----------------------------------------------\n', log)
            lprint('>>Load weights dict {}'.format(config.weights_dir), log)
            lprint('>>Testing pairs: {}'.format(config.pair_txt), log)

            test_sets = get_datasets(datasets=config.datasets,
                                     pair_txt=config.pair_txt,
                                     data_root=config.data_root,
                                     incl_sces=config.incl_sces,
                                     ops=config.ops,
                                     train_lbl_txt=config.train_lbl_txt,
                                     test_lbl_txt=config.test_lbl_txt,
                                     with_ess=False)
            test_loaders = {}
            for test_set in test_sets:
                lprint(
                    'Testing scene {} samples: {}'.format(test_set.scene,
                                                          len(test_set)), log)
                test_loaders[test_set.scene] = make_eval_loader(test_set)
            test(net, config, log, test_loaders)