Пример #1
0
# Image processing parameters. The real pre-processing is performed inside the
# model functions, so these transforms only put the channel axis first.
input_images_transform = transforms.Compose([ArrayToTensor(get_float=False)])
gt_flow_transform = transforms.Compose([ArrayToTensor()])
# No co_transform is used here; it is explicitly left as None.
co_transform = None

for pre_trained_model_type in pre_trained_models:
    print(pre_trained_model_type)
    with torch.no_grad():
        # Instantiate the architecture selected by args.model with the
        # current pre-trained weight variant.
        # NOTE(review): any other args.model value leaves `network` unset (or
        # still holding the previous iteration's model) — confirm intended.
        chosen_model = args.model
        if chosen_model in ('GLUNet', 'SemanticGLUNet'):
            semantic = chosen_model == 'SemanticGLUNet'
            # GLU-Net uses cyclic consistency; Semantic-GLU-Net swaps it for
            # the consensus network.
            glu_kwargs = dict(model_type=pre_trained_model_type,
                              cyclic_consistency=not semantic,
                              consensus_network=semantic,
                              iterative_refinement=True,
                              apply_flipping_condition=args.flipping_condition)
            if semantic:
                # only the semantic variant sets feature_concatenation explicitly
                glu_kwargs['feature_concatenation'] = True
            network = GLU_Net(**glu_kwargs)
        elif chosen_model == 'GLOCALNet':
            network = GLOCAL_Net(model_type=pre_trained_model_type,
                                 constrained_corr=True,
                                 global_corr=True)
torch.set_grad_enabled(False)  # make sure to not compute gradients for computational performance
torch.backends.cudnn.enabled = True
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # either gpu or cpu

# Every input file below is resolved relative to this dataset folder.
dataset_name = 'CPC'
# Per-pair source folder, looked up row by row with .iloc[idx] when the
# image paths are built.
# NOTE(review): read_csv treats the first line as a header by default; if this
# file has no header row its rows are shifted by one relative to pairs_gts
# (pass header=None in that case) — confirm against the data.
pairs_which_dataset = pd.read_csv(dataset_name + '/pairs_which_dataset.txt')
# ndmin=2 keeps a single-pair file two-dimensional. Without it np.loadtxt
# returns a 1-D array, pandas turns that into a single column, and the column
# slicing below would no longer select the two image ids.
pairs_gts = pd.DataFrame(np.loadtxt(dataset_name + '/pairs_with_gt.txt', ndmin=2))
# The first two columns hold the left/right image ids of each pair; any
# further columns are presumably ground-truth data — TODO confirm.
pairs = pairs_gts[pairs_gts.columns[0:2]]
l_pairs = pairs[pairs.columns[0]]
r_pairs = pairs[pairs.columns[1]]
size = len(pairs_gts)  # number of image pairs (one per row)
# Result container keyed by 'Matches'; presumably filled per pair later in
# the script (not visible here).
matches = {'Matches': []}
with torch.no_grad():
    network = GLU_Net(path_pre_trained_models=args.pre_trained_models_dir,
                      model_type=args.pre_trained_model,
                      consensus_network=False,
                      cyclic_consistency=True,
                      iterative_refinement=True,
                      apply_flipping_condition=False)
    for idx in range(size):
        print(idx)
        l = "{:08}".format(int(l_pairs[idx]))
        r = "{:08}".format(int(r_pairs[idx]))
        folder = pairs_which_dataset.iloc[idx]
        I1 = dataset_name + '/' + folder[0] + 'Images/' + l + '.jpg'
        I2 = dataset_name + '/' + folder[0] + 'Images/' + r + '.jpg'
        # I2 = imageio.imread(pairs_which_dataset(idx) + 'Images/' sprintf('%.8d.jpg', r)]);

        try:
            source_image = imageio.imread(I1)
            target_image = imageio.imread(I2)
            source_image, target_image = pad_to_same_shape(source_image, target_image)