예제 #1
0
    normTransform = transforms.Normalize(mean_vector, std_vector)

    if not args.pre_loaded_training_dataset:
        # training dataset, created on the fly at each epoch

        # training data, big data (520,520) rescaled to 256x256 to fit the fixed input of network,
        # then pre-processing is applied here(whereas in GLUNet, it is within the function)
        source_transforms = transforms.Compose([transforms.ToPILImage(),
                                                transforms.Resize(256),
                                                transforms.ToTensor(),
                                                normTransform])
        pyramid_param = [256] # means that we get the ground-truth flow field at this size
        train_dataset = HomoAffTps_Dataset(image_path=args.training_data_dir,
                                           csv_file=osp.join('datasets', 'csv_files',
                                                         'homo_aff_tps_train_DPED_CityScape_ADE.csv'),
                                           transforms=source_transforms,
                                           transforms_target=source_transforms,
                                           pyramid_param=pyramid_param,
                                           get_flow=True,
                                           output_size=(520, 520))

        # validation dataset
        pyramid_param = [256]
        val_dataset = HomoAffTps_Dataset(image_path=args.evaluation_data_dir,
                                         csv_file=osp.join('datasets', 'csv_files',
                                                           'homo_aff_tps_test_DPED_CityScape_ADE.csv'),
                                         transforms=source_transforms,
                                         transforms_target=source_transforms,
                                         pyramid_param=pyramid_param,
                                         get_flow=True,
                                         output_size=(520, 520))
        os.makedirs(image_dir)
    if not os.path.exists(flow_dir):
        os.makedirs(flow_dir)

    # datasets
    source_img_transforms = transforms.Compose(
        [ArrayToTensor(get_float=False)])
    target_img_transforms = transforms.Compose(
        [ArrayToTensor(get_float=False)])
    pyramid_param = [520]

    # training dataset
    train_dataset = HomoAffTps_Dataset(image_path=args.image_data_path,
                                       csv_file=args.csv_path,
                                       transforms=source_img_transforms,
                                       transforms_target=target_img_transforms,
                                       pyramid_param=pyramid_param,
                                       get_flow=True,
                                       output_size=(520, 520))

    test_dataloader = DataLoader(train_dataset,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1)

    pbar = tqdm(enumerate(test_dataloader), total=len(test_dataloader))
    for i, minibatch in pbar:
        image_source = minibatch['source_image']  # shape is 1x3xHxW
        image_target = minibatch['target_image']
        if image_source.shape[1] == 3:
            image_source = image_source.permute(0, 2, 3,