Пример #1
0
def get_data(args):
    """Create the 'tokyo' dataset and its query / database test loaders.

    Args:
        args: Options object. This function reads ``data_dir``, ``height``,
            ``width``, ``workers`` and ``test_batch_size`` from it.

    Returns:
        tuple: ``(dataset, test_loader_q, test_loader_db)``.
    """
    tokyo_root = osp.join(args.data_dir, 'tokyo')
    dataset = datasets.create('tokyo', tokyo_root)

    db_transform = get_transformer_test(args.height, args.width)
    # Queries get a separate transform built with the ``tokyo`` flag enabled.
    query_transform = get_transformer_test(args.height, args.width, tokyo=True)

    # Options common to both loaders; iteration order is left to the sampler.
    shared_opts = dict(num_workers=args.workers,
                       shuffle=False,
                       pin_memory=True)

    query_images = Preprocessor(dataset.q_test,
                                root=dataset.images_dir,
                                transform=query_transform)
    query_sampler = DistributedSliceSampler(dataset.q_test)
    # Queries are served one image per batch.
    # NOTE(review): presumably because query images differ in size -- confirm.
    test_loader_q = DataLoader(query_images,
                               batch_size=1,
                               sampler=query_sampler,
                               **shared_opts)

    db_images = Preprocessor(dataset.db_test,
                             root=dataset.images_dir,
                             transform=db_transform)
    db_sampler = DistributedSliceSampler(dataset.db_test)
    test_loader_db = DataLoader(db_images,
                                batch_size=args.test_batch_size,
                                sampler=db_sampler,
                                **shared_opts)

    return dataset, test_loader_q, test_loader_db
Пример #2
0
def get_data(args, iters):
    """Create the dataset plus its training and evaluation loaders.

    Args:
        args: Options object. This function reads ``data_dir``, ``dataset``,
            ``scale``, ``height``, ``width``, ``neg_num``, ``neg_pool``,
            ``tuple_size``, ``workers`` and ``test_batch_size`` from it.
        iters: Forwarded to ``IterLoader`` as its ``length`` argument.

    Returns:
        tuple: ``(dataset, train_loader, val_loader, test_loader, sampler,
        train_extract_loader)`` where ``sampler`` is the tuple sampler that
        drives ``train_loader``.
    """
    root = osp.join(args.data_dir, args.dataset)
    dataset = datasets.create(args.dataset, root, scale=args.scale)

    train_transformer = get_transformer_train(args.height, args.width)
    test_transformer = get_transformer_test(args.height, args.width)

    def _make_eval_loader(queries, database):
        """Return a fixed-order loader over the sorted union of two splits."""
        # Build the sorted, de-duplicated list once and hand the very same
        # sequence to both the dataset wrapper and the sampler.
        images = sorted(set(queries) | set(database))
        return DataLoader(
            Preprocessor(images,
                         root=dataset.images_dir,
                         transform=test_transformer),
            batch_size=args.test_batch_size,
            num_workers=args.workers,
            sampler=DistributedSliceSampler(images),
            shuffle=False,
            pin_memory=True)

    sampler = DistributedRandomTupleSampler(dataset.q_train,
                                            dataset.db_train,
                                            dataset.train_pos,
                                            dataset.train_neg,
                                            neg_num=args.neg_num,
                                            neg_pool=args.neg_pool)
    # The training set is the plain concatenation of queries and database
    # images (no de-duplication, unlike the evaluation loaders below).
    train_loader = IterLoader(
        DataLoader(
            Preprocessor(dataset.q_train + dataset.db_train,
                         root=dataset.images_dir,
                         transform=train_transformer),
            batch_size=args.tuple_size,
            num_workers=args.workers,
            sampler=sampler,
            shuffle=False,
            pin_memory=True,
            drop_last=True),
        length=iters)

    # Same images as training, but with the test-time transform.
    train_extract_loader = _make_eval_loader(dataset.q_train,
                                             dataset.db_train)
    val_loader = _make_eval_loader(dataset.q_val, dataset.db_val)
    test_loader = _make_eval_loader(dataset.q_test, dataset.db_test)

    return dataset, train_loader, val_loader, test_loader, sampler, train_extract_loader
Пример #3
0
def get_data(args):
    """Create the test loaders for ``args.dataset`` plus a 'pitts' loader.

    Args:
        args: Options object. This function reads ``data_dir``, ``dataset``,
            ``scale``, ``height``, ``width``, ``workers`` and
            ``test_batch_size`` from it.

    Returns:
        tuple: ``(dataset, pitts_train, train_extract_loader, test_loader_q,
        test_loader_db)`` where ``pitts_train`` is the sorted, de-duplicated
        union of the 'pitts' query and database training lists.
    """
    is_tokyo = args.dataset == 'tokyo'

    data_root = osp.join(args.data_dir, args.dataset)
    dataset = datasets.create(args.dataset, data_root, scale=args.scale)

    db_transform = get_transformer_test(args.height, args.width)
    # The query transform is told whether the target dataset is 'tokyo'.
    query_transform = get_transformer_test(args.height,
                                           args.width,
                                           tokyo=is_tokyo)

    # Options common to every loader; iteration order is left to the sampler.
    shared_opts = dict(num_workers=args.workers,
                       shuffle=False,
                       pin_memory=True)

    # The extraction loader always reads the 'pitts' training images at the
    # '30k' scale, whichever dataset is being tested.
    pitts_root = osp.join(args.data_dir, 'pitts')
    pitts = datasets.create('pitts', pitts_root, scale='30k', verbose=False)
    pitts_train = sorted(set(pitts.q_train) | set(pitts.db_train))
    pitts_images = Preprocessor(pitts_train,
                                root=pitts.images_dir,
                                transform=db_transform)
    train_extract_loader = DataLoader(
        pitts_images,
        batch_size=args.test_batch_size,
        sampler=DistributedSliceSampler(pitts_train),
        **shared_opts)

    query_images = Preprocessor(dataset.q_test,
                                root=dataset.images_dir,
                                transform=query_transform)
    # 'tokyo' queries go through one image per batch; other datasets use the
    # regular test batch size.
    if is_tokyo:
        query_batch_size = 1
    else:
        query_batch_size = args.test_batch_size
    test_loader_q = DataLoader(
        query_images,
        batch_size=query_batch_size,
        sampler=DistributedSliceSampler(dataset.q_test),
        **shared_opts)

    db_images = Preprocessor(dataset.db_test,
                             root=dataset.images_dir,
                             transform=db_transform)
    test_loader_db = DataLoader(
        db_images,
        batch_size=args.test_batch_size,
        sampler=DistributedSliceSampler(dataset.db_test),
        **shared_opts)

    return dataset, pitts_train, train_extract_loader, test_loader_q, test_loader_db