示例#1
0
def build_cls_test_loader(cfg, dataset_name, mapper=None, **kwargs):
    """Build the inference DataLoader for a classification dataset.

    Args:
        cfg: config node. It is cloned here. ``cfg.TEST.IMS_PER_BATCH``
            (the global test batch size) is read, plus whatever
            ``build_transforms`` reads when ``mapper`` is ``None``.
        dataset_name: key looked up in ``DATASET_REGISTRY``.
        mapper: optional per-sample transform. When ``None``, the default
            test-time transforms from ``build_transforms(cfg, is_train=False)``
            are used.
        **kwargs: forwarded unchanged to the dataset constructor.

    Returns:
        A ``DataLoader`` over the dataset's ``query`` items, batched by a
        ``BatchSampler`` over ``samplers.InferenceSampler`` and collated with
        ``fast_batch_collator``.
    """
    cfg = cfg.clone()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    if comm.is_main_process():
        # Only one process prints the dataset summary.
        dataset.show_test()
    test_items = dataset.query

    # The conditional expression only calls build_transforms when no mapper
    # was supplied.
    transforms = (mapper if mapper is not None
                  else build_transforms(cfg, is_train=False))

    test_set = CommDataset(test_items, transforms, relabel=False)

    # Split the global test batch across processes. Clamp to 1 so that a
    # global batch smaller than the world size does not produce
    # batch_size=0, which BatchSampler rejects with a ValueError.
    mini_batch_size = max(1, cfg.TEST.IMS_PER_BATCH // comm.get_world_size())
    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: keep the final partial batch so no index yielded by
    # the sampler is dropped.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler,
                                                  mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
示例#2
0
    def build_test_loader(cls, cfg, test_set):
        """Build the inference DataLoader for an already-constructed dataset.

        Args:
            cfg: config node. Reads ``cfg.TEST.IMS_PER_BATCH`` (the global
                test batch size) and ``cfg.DATALOADER.NUM_WORKERS``.
            test_set: dataset to iterate. ``len(test_set)`` sizes the
                sampler, and the dataset itself is handed to the loader.

        Returns:
            A ``DataLoaderX`` constructed with this process's local rank
            (presumably for device placement; confirm against DataLoaderX).
            Batches come from a ``BatchSampler`` over
            ``samplers.InferenceSampler`` and are collated with
            ``fast_batch_collator``.
        """
        logger = logging.getLogger('fastreid')
        logger.info("Prepare testing loader")

        # Split the global test batch across processes. Clamp to 1 so that
        # a global batch smaller than the world size does not produce
        # batch_size=0, which BatchSampler rejects with a ValueError.
        mini_batch_size = max(
            1, cfg.TEST.IMS_PER_BATCH // comm.get_world_size())
        data_sampler = samplers.InferenceSampler(len(test_set))
        # drop_last=False: keep the final partial batch so no index yielded
        # by the sampler is dropped.
        batch_sampler = BatchSampler(data_sampler, mini_batch_size, False)
        test_loader = DataLoaderX(
            comm.get_local_rank(),
            dataset=test_set,
            batch_sampler=batch_sampler,
            num_workers=cfg.DATALOADER.NUM_WORKERS,
            collate_fn=fast_batch_collator,
            pin_memory=True,
        )

        return test_loader
示例#3
0
def build_attr_test_loader(cfg, dataset_name):
    """Build the inference DataLoader for an attribute dataset.

    Args:
        cfg: config node. It is cloned and the clone is defrosted, so the
            caller's config is never touched. Reads
            ``cfg.DATASETS.COMBINEALL``, ``cfg.TEST.IMS_PER_BATCH`` (the
            global test batch size), and whatever ``build_transforms`` reads.
        dataset_name: key looked up in ``DATASET_REGISTRY``.

    Returns:
        A ``DataLoader`` over the dataset's ``test`` items wrapped in
        ``AttrDataset`` (together with ``dataset.attr_dict``), collated with
        ``fast_batch_collator``.
    """
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(
        root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        # Only one process prints the dataset summary.
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)

    # Split the global test batch across processes. Clamp to 1 so that a
    # global batch smaller than the world size does not produce
    # batch_size=0, which BatchSampler rejects with a ValueError.
    mini_batch_size = max(1, cfg.TEST.IMS_PER_BATCH // comm.get_world_size())
    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: keep the final partial batch so no index yielded by
    # the sampler is dropped.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler,
                                                  mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
示例#4
0
def build_face_test_loader(cfg, dataset_name, **kwargs):
    """Build the inference DataLoader for a face dataset.

    Args:
        cfg: config node. Only ``cfg.TEST.IMS_PER_BATCH`` (the global test
            batch size) is read, so it is not cloned.
        dataset_name: key looked up in ``DATASET_REGISTRY``.
        **kwargs: forwarded unchanged to the dataset constructor.

    Returns:
        A tuple ``(test_loader, labels)``:

        - ``test_loader`` is a ``DataLoader`` over a ``FaceCommDataset``
          built from ``dataset.carray`` and ``dataset.is_same``.
        - ``labels`` is ``test_set.labels``. Presumably these are the
          same/different pair flags used for verification; confirm in
          FaceCommDataset.
    """
    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    if comm.is_main_process():
        # Only one process prints the dataset summary.
        dataset.show_test()

    test_set = FaceCommDataset(dataset.carray, dataset.is_same)

    # Split the global test batch across processes. Clamp to 1 so that a
    # global batch smaller than the world size does not produce
    # batch_size=0, which BatchSampler rejects with a ValueError.
    mini_batch_size = max(1, cfg.TEST.IMS_PER_BATCH // comm.get_world_size())
    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: keep the final partial batch so no index yielded by
    # the sampler is dropped.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler,
                                                  mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader, test_set.labels