# Define the image processing parameters; the actual pre-processing is done
# within the model functions.
#
# NOTE(review): these statements were previously collapsed onto one physical
# line that began with '#', so Python read the whole line as a comment and
# none of them ran. The line structure is restored here; no statement,
# argument or string literal was changed.
input_images_transform = transforms.Compose([ArrayToTensor(get_float=False)])  # only put channel first
gt_flow_transform = transforms.Compose([ArrayToTensor()])  # only put channel first
co_transform = None

for pre_trained_model_type in pre_trained_models:
    print(pre_trained_model_type)
    with torch.no_grad():
        # Define the network to use; the architecture is picked from
        # args.model and the weights variant from pre_trained_model_type.
        if args.model == 'GLUNet':
            network = GLU_Net(model_type=pre_trained_model_type,
                              consensus_network=False,
                              cyclic_consistency=True,
                              iterative_refinement=True,
                              apply_flipping_condition=args.flipping_condition)
        elif args.model == 'SemanticGLUNet':
            network = GLU_Net(model_type=pre_trained_model_type,
                              feature_concatenation=True,
                              cyclic_consistency=False,
                              consensus_network=True,
                              iterative_refinement=True,
                              apply_flipping_condition=args.flipping_condition)
        elif args.model == 'GLOCALNet':
            network = GLOCAL_Net(model_type=pre_trained_model_type,
                                 constrained_corr=True,
                                 global_corr=True)
        # NOTE(review): there is no else branch, so an unrecognised
        # args.model leaves `network` unassigned at this point. Confirm the
        # argument parser restricts the accepted values.
# Evaluation driver for image pairs of the 'CPC' dataset.
#
# NOTE(review): the statements of this script are collapsed onto the single
# physical line below. Python ends a statement's code at the first '#', so as
# written only `torch.set_grad_enabled(False)` executes; everything after it
# is read as comment text.
#
# The collapsed text contains, in order:
#   1. enabling cuDNN and selecting "cuda" when available, otherwise "cpu";
#   2. reading 'CPC/pairs_which_dataset.txt' with pandas and
#      'CPC/pairs_with_gt.txt' with numpy (then wrapped in a DataFrame),
#      taking the first two columns as the left / right image ids;
#   3. building a GLU_Net from args.pre_trained_models_dir and
#      args.pre_trained_model under torch.no_grad();
#   4. a loop over the pairs that formats each id as a zero-padded 8-digit
#      '.jpg' path under '<dataset>/<folder>Images/', reads both images with
#      imageio and passes them to pad_to_same_shape.
#
# NOTE(review): the text ends inside an open `try:` with no handler visible
# here, so the line breaks and indentation cannot be restored from this line
# alone. Recover them from the original script before relying on this code.
torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance torch.backends.cudnn.enabled = True device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # either gpu or cpu dataset_name = 'CPC' pairs_which_dataset = pd.read_csv(dataset_name + '/pairs_which_dataset.txt') pairs_gts = np.loadtxt(dataset_name + '/pairs_with_gt.txt') pairs_gts = pd.DataFrame(pairs_gts) pairs = pairs_gts[pairs_gts.columns[0:2]] l_pairs = pairs[pairs.columns[0]] r_pairs = pairs[pairs.columns[1]] size = np.size(pairs_gts[0]) matches = {'Matches': []} with torch.no_grad(): network = GLU_Net(path_pre_trained_models=args.pre_trained_models_dir, model_type=args.pre_trained_model, consensus_network=False, cyclic_consistency=True, iterative_refinement=True, apply_flipping_condition=False) for idx in range(size): print(idx) l = "{:08}".format(int(l_pairs[idx])) r = "{:08}".format(int(r_pairs[idx])) folder = pairs_which_dataset.iloc[idx] I1 = dataset_name + '/' + folder[0] + 'Images/' + l + '.jpg' I2 = dataset_name + '/' + folder[0] + 'Images/' + r + '.jpg' # I2 = imageio.imread(pairs_which_dataset(idx) + 'Images/' sprintf('%.8d.jpg', r)]); try: source_image = imageio.imread(I1) target_image = imageio.imread(I2) source_image, target_image = pad_to_same_shape(source_image, target_image)