def eval_KITTI(method, args, *, root=None, num_node=15000, num_workers=8):
    """Evaluate a registration method on the KITTI test split and log a summary.

    Builds the KITTI test dataset and dataloader, runs ``eval_KITTI_scene``
    over it and logs pair-level averages.

    Args:
        method: Registration method, forwarded to ``eval_KITTI_scene``.
        args: Options object. This function reads ``args.descriptor``,
            ``args.inlier_threshold`` and ``args.use_mutual``, and forwards
            the whole object to ``eval_KITTI_scene``.
        root: Dataset root directory. When None, the module-level
            ``kitti_dir`` is used.
        num_node: Value passed as ``num_node`` to ``KITTIDataset``
            (default 15000).
        num_workers: Number of dataloader worker processes (default 8).

    Returns:
        The per-pair statistics returned by ``eval_KITTI_scene``, unchanged.
        The log messages below interpret the columns as:
            0 = success flag (1 for a successful pair),
            1 = rotation error (Re), 2 = translation error (Te),
            3/4 = input inlier number / ratio,
            5-8 = output inlier number / precision / recall / F1,
            9 = model time, 10 = data time.
    """
    dset = KITTIDataset(
        # Conditional expression: kitti_dir is only looked up when no root is given.
        root=kitti_dir if root is None else root,
        split='test',
        descriptor=args.descriptor,
        in_dim=6,
        inlier_threshold=args.inlier_threshold,
        num_node=num_node,
        use_mutual=args.use_mutual,
        augment_axis=0,
        augment_rotation=0.00,
        augment_translation=0.0,
    )
    dloader = get_dataloader(dset, batch_size=1, num_workers=num_workers, shuffle=False)
    allpair_stats = eval_KITTI_scene(method, dloader, args)

    # Pair-level average over every evaluated pair.
    allpair_average = allpair_stats.mean(0)

    # Re/Te are only averaged over successful pairs (column 0 == 1). Averaging
    # an empty selection gives NaN (and a "Mean of empty slice" warning when
    # the stats are a NumPy array), so that case is reported explicitly.
    success_mask = allpair_stats[:, 0] == 1
    if int(success_mask.sum()) > 0:
        correct_pair_average = allpair_stats[success_mask].mean(0)
        mean_re, mean_te = correct_pair_average[1], correct_pair_average[2]
    else:
        logging.warning("No pair was registered successfully; Mean Re/Te are undefined.")
        mean_re = mean_te = float('nan')

    logging.info("*" * 40)
    logging.info(
        f"All {allpair_stats.shape[0]} pairs, Mean Reg Recall={allpair_average[0] * 100:.2f}%, Mean Re={mean_re:.2f}, Mean Te={mean_te:.2f}"
    )
    logging.info(
        f"\tInput: Mean Inlier Num={allpair_average[3]:.2f}(ratio={allpair_average[4] * 100:.2f}%)"
    )
    logging.info(
        f"\tOutput: Mean Inlier Num={allpair_average[5]:.2f}(precision={allpair_average[6] * 100:.2f}%, recall={allpair_average[7] * 100:.2f}%, f1={allpair_average[8] * 100:.2f}%)"
    )
    logging.info(
        f"\tMean model time: {allpair_average[9]:.2f}s, Mean data time: {allpair_average[10]:.2f}s"
    )
    return allpair_stats
def eval_KITTI(model, config, use_icp, *, root='/data/KITTI', num_node=12000, num_workers=16):
    """Evaluate ``model`` on the KITTI test split and log a pair-level summary.

    Builds the KITTI test dataset and dataloader, runs ``eval_KITTI_per_pair``,
    logs the peak CUDA memory (when a GPU is available) and logs pair-level
    averages.

    Args:
        model: Model, forwarded to ``eval_KITTI_per_pair``.
        config: Options object. This function reads ``config.descriptor``,
            ``config.in_dim``, ``config.inlier_threshold`` and
            ``config.use_mutual``, and forwards the whole object to
            ``eval_KITTI_per_pair``.
        use_icp: Flag forwarded to ``eval_KITTI_per_pair``.
        root: Dataset root directory (default '/data/KITTI').
        num_node: Value passed as ``num_node`` to ``KITTIDataset``
            (default 12000).
        num_workers: Number of dataloader worker processes (default 16).

    Returns:
        The per-pair statistics returned by ``eval_KITTI_per_pair``, unchanged.
        The log messages below interpret the columns as:
            0 = success flag (1 for a successful pair),
            1 = rotation error (Re), 2 = translation error (Te),
            3/4 = input inlier number / ratio,
            5-8 = output inlier number / precision / recall / F1,
            9 = model time, 10 = data time.
    """
    dset = KITTIDataset(
        root=root,
        split='test',
        descriptor=config.descriptor,
        in_dim=config.in_dim,
        inlier_threshold=config.inlier_threshold,
        num_node=num_node,
        use_mutual=config.use_mutual,
        augment_axis=0,
        augment_rotation=0.00,
        augment_translation=0.0,
    )
    dloader = get_dataloader(dset, batch_size=1, num_workers=num_workers, shuffle=False)
    allpair_stats = eval_KITTI_per_pair(model, dloader, config, use_icp)

    # Peak CUDA memory is only meaningful when a CUDA device is present.
    if torch.cuda.is_available():
        logging.info(f"Max memory allocated: {torch.cuda.max_memory_allocated() / 1024 ** 3:.2f}GB")

    # Pair-level average over every evaluated pair.
    allpair_average = allpair_stats.mean(0)

    # Re/Te are only averaged over successful pairs (column 0 == 1). Averaging
    # an empty selection gives NaN, so that case is reported explicitly.
    success_mask = allpair_stats[:, 0] == 1
    if int(success_mask.sum()) > 0:
        correct_pair_average = allpair_stats[success_mask].mean(0)
        mean_re, mean_te = correct_pair_average[1], correct_pair_average[2]
    else:
        logging.warning("No pair was registered successfully; Mean Re/Te are undefined.")
        mean_re = mean_te = float('nan')

    logging.info("*" * 40)
    logging.info(f"All {allpair_stats.shape[0]} pairs, Mean Success Rate={allpair_average[0]*100:.2f}%, Mean Re={mean_re:.2f}, Mean Te={mean_te:.2f}")
    logging.info(f"\tInput: Mean Inlier Num={allpair_average[3]:.2f}(ratio={allpair_average[4]*100:.2f}%)")
    logging.info(f"\tOutput: Mean Inlier Num={allpair_average[5]:.2f}(precision={allpair_average[6]*100:.2f}%, recall={allpair_average[7]*100:.2f}%, f1={allpair_average[8]*100:.2f}%)")
    logging.info(f"\tMean model time: {allpair_average[9]:.2f}s, Mean data time: {allpair_average[10]:.2f}s")
    return allpair_stats
def calculate_repeatability(desc_name, timestr, num_keypts):
    """Compute the average relative repeatability of saved KITTI keypoints.

    For every pair in the KITTI test split, loads the ``.npz`` file
    ``geometric_registration_kitti/{desc_name}_{timestr}/{drive}@{t0}-{t1}.npz``.
    It takes the last ``num_keypts`` entries of its ``anc_pts`` and
    ``pos_pts`` arrays and scores them with ``deal_with_one_pair`` at
    ``threshold=0.5``. Pairs whose file does not exist are skipped.

    Args:
        desc_name: Descriptor name, used in the results directory name.
        timestr: Timestamp suffix, used in the results directory name.
        num_keypts: Number of keypoints kept per cloud, taken from the end
            of each array. Should be positive: a value of 0 would slice the
            whole array (``arr[-0:]``).

    Returns:
        The mean of the per-pair repeatability values. This is NaN when no
        pair file was found.
    """
    from datasets.KITTI import KITTIDataset

    dataset = KITTIDataset(1, first_subsampling_dl=0.3, load_test=True)
    keyptspath = f"geometric_registration_kitti/{desc_name}_{timestr}"

    repeat_list = []
    for pair in dataset.files['test']:
        drive, t0, t1 = pair[0], pair[1], pair[2]
        pair_file = os.path.join(keyptspath, f'{drive}@{t0}-{t1}.npz')
        if not os.path.exists(pair_file):
            continue
        # np.load on an .npz returns an NpzFile that keeps the archive open;
        # the context manager closes it once this pair has been scored.
        with np.load(pair_file) as data:
            source_keypts = data['anc_pts'][-num_keypts:]
            target_keypts = data['pos_pts'][-num_keypts:]
            repeat_list.append(
                deal_with_one_pair(source_keypts, target_keypts, data['trans'], num_keypts, threshold=0.5)
            )

    # np.mean([]) would return NaN with a RuntimeWarning; make that case explicit.
    average = np.mean(repeat_list) if repeat_list else np.nan
    print(
        f"Average Repeatability at num_keypts = {num_keypts}: {average}"
    )
    return average
########################### # Load the model parameters ########################### config = KITTIConfig() ############## # Prepare Data ############## print() print('Dataset Preparation') print('*******************') # Initiate dataset configuration dataset = KITTIDataset(config.input_threads, config.first_subsampling_dl) # Create subsampled input clouds dl0 = config.first_subsampling_dl # dataset.load_subsampled_clouds(dl0) # Initialize input pipelines dataset.init_input_pipeline(config) # Test the input pipeline alone with this debug function # dataset.check_input_pipeline_timing(config) ############## # Define Model ##############
def test_caller(path, step_ind, on_val):
    """Restore a trained model snapshot from ``path`` and run the KITTI test.

    Loads the saved configuration from ``path``, builds the KITTI test input
    pipeline and a ``KernelPointFCNN`` model, and restores the snapshot
    selected by ``step_ind``. It then calls ``ModelTester.test_kitti``.

    Args:
        path: Training log directory. Must be loadable by ``Config.load`` and
            contain a ``snapshots`` sub-directory with ``*.meta`` files.
        step_ind: Index into the ascending list of snapshot step numbers
            (for example ``-1`` selects the latest snapshot).
        on_val: Not used by this function.

    Raises:
        FileNotFoundError: If ``path/snapshots`` does not exist or contains
            no ``.meta`` snapshot file.
        IndexError: If ``step_ind`` is out of range for the snapshots found.
    """
    # NOTE(review): '0' makes TensorFlow's C++ backend emit every log level,
    # so it does NOT disable warnings. '2' would hide warnings and '3' errors
    # too. The value is kept as-is; confirm which level is intended.
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0'

    ###########################
    # Load the model parameters
    ###########################

    # Load model parameters
    config = Config()
    config.load(path)

    ##################################
    # Change model parameters for test
    ##################################

    # Change parameters for the test here. For example, you can stop augmenting the input data.
    #config.augment_noise = 0.0001
    #config.augment_color = 1.0
    #config.validation_size = 500
    #config.batch_num = 10

    ##############
    # Prepare Data
    ##############

    print()
    print('Dataset Preparation')
    print('*******************')

    # Initiate dataset configuration
    dataset = KITTIDataset(1, config.first_subsampling_dl, load_test=True)

    # Initialize input pipelines
    dataset.init_test_input_pipeline(config)

    ##############
    # Define Model
    ##############

    print('Creating Model')
    print('**************\n')
    t1 = time.time()

    model = KernelPointFCNN(dataset.flat_inputs, config)

    # Snapshot step numbers are parsed from '*.meta' file names: the text
    # after the last '-' in the name without its '.meta' suffix.
    snap_path = os.path.join(path, 'snapshots')
    snap_steps = sorted(
        int(fname[:-5].split('-')[-1])
        for fname in os.listdir(snap_path)
        if fname.endswith('.meta')
    )
    if not snap_steps:
        raise FileNotFoundError(f"No '.meta' snapshot found in {snap_path}")

    # Find which snapshot to restore
    chosen_step = snap_steps[step_ind]
    chosen_snap = os.path.join(snap_path, 'snap-{:d}'.format(chosen_step))

    # Create a tester class
    tester = ModelTester(model, restore_snap=chosen_snap)
    t2 = time.time()

    print('\n----------------')
    print('Done in {:.1f} s'.format(t2 - t1))
    print('----------------\n')

    ############
    # Start test
    ############

    print('Start Test')
    print('**********\n')

    tester.test_kitti(model, dataset)
# momentum=config.momentum, weight_decay=config.weight_decay, ) config.scheduler = optim.lr_scheduler.ExponentialLR( config.optimizer, gamma=config.scheduler_gamma, ) # create dataset and dataloader train_set = KITTIDataset( root=config.root, split='train', descriptor=config.descriptor, in_dim=config.in_dim, inlier_threshold=config.inlier_threshold, num_node=config.num_node, use_mutual=config.use_mutual, augment_axis=config.augment_axis, augment_rotation=config.augment_rotation, augment_translation=config.augment_translation, ) val_set = KITTIDataset( root=config.root, split='val', descriptor=config.descriptor, in_dim=config.in_dim, inlier_threshold=config.inlier_threshold, num_node=config.num_node, use_mutual=config.use_mutual, augment_axis=config.augment_axis, augment_rotation=config.augment_rotation,