def build_cls_test_loader(cfg, dataset_name, mapper=None, **kwargs):
    """Build the evaluation DataLoader for a classification dataset.

    The dataset class is looked up in ``DATASET_REGISTRY`` by name, and its
    ``query`` split is used as the evaluation set.

    Args:
        cfg: config node. ``cfg.TEST.IMS_PER_BATCH`` is the global test batch
            size; it is floor-divided by the number of processes to get the
            per-process batch size.
        dataset_name: registry key of the dataset class.
        mapper: optional transform applied to every item. When ``None``, the
            default ``build_transforms(cfg, is_train=False)`` is used.
        **kwargs: forwarded unchanged to the dataset constructor.

    Returns:
        A ``DataLoader`` whose batches are assembled by ``fast_batch_collator``.

    Raises:
        ValueError: if ``cfg.TEST.IMS_PER_BATCH`` is smaller than the number
            of processes, which would give every process a batch size of 0.
    """
    # Work on a clone so the caller's config object is left untouched.
    cfg = cfg.clone()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    # Print the split summary once, not once per process.
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.query

    # Conditional expression keeps build_transforms() lazy: it only runs
    # when no mapper was supplied.
    transforms = mapper if mapper is not None else build_transforms(cfg, is_train=False)
    test_set = CommDataset(test_items, transforms, relabel=False)

    world_size = comm.get_world_size()
    mini_batch_size = cfg.TEST.IMS_PER_BATCH // world_size
    if mini_batch_size < 1:
        # BatchSampler would reject 0 anyway, but with a message that names
        # neither the config key nor the world size.
        raise ValueError(
            "cfg.TEST.IMS_PER_BATCH ({}) must be >= the number of processes ({}); "
            "otherwise each process gets a batch size of 0".format(
                cfg.TEST.IMS_PER_BATCH, world_size))

    # NOTE(review): InferenceSampler presumably shards indices across ranks
    # without shuffling -- confirm against samplers module.
    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: the final, possibly smaller, batch is kept.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)

    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
def build_test_loader(cls, cfg, test_set):
    """Wrap an already-built test dataset in a per-process DataLoaderX.

    Args:
        cls: owning class. It is not used in the body; presumably this is a
            classmethod whose decorator lies outside this view.
        cfg: config node. Reads ``cfg.TEST.IMS_PER_BATCH`` (global batch size,
            floor-divided by the number of processes) and
            ``cfg.DATALOADER.NUM_WORKERS``.
        test_set: dataset to iterate; only ``len(test_set)`` is inspected here.

    Returns:
        A ``DataLoaderX`` built for the local rank, whose batches are assembled
        by ``fast_batch_collator``.

    Raises:
        ValueError: if ``cfg.TEST.IMS_PER_BATCH`` is smaller than the number
            of processes, which would give every process a batch size of 0.
    """
    logger = logging.getLogger('fastreid')
    logger.info("Prepare testing loader")

    world_size = comm.get_world_size()
    test_batch_size = cfg.TEST.IMS_PER_BATCH
    mini_batch_size = test_batch_size // world_size
    if mini_batch_size < 1:
        # BatchSampler would reject 0 anyway, but with a message that names
        # neither the config key nor the world size.
        raise ValueError(
            "cfg.TEST.IMS_PER_BATCH ({}) must be >= the number of processes ({}); "
            "otherwise each process gets a batch size of 0".format(
                test_batch_size, world_size))
    num_workers = cfg.DATALOADER.NUM_WORKERS

    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: the final, possibly smaller, batch is kept.
    batch_sampler = BatchSampler(data_sampler, mini_batch_size, False)

    # DataLoaderX receives the local rank as its first positional argument;
    # NOTE(review): presumably used for device placement/prefetching -- confirm.
    test_loader = DataLoaderX(
        comm.get_local_rank(),
        dataset=test_set,
        batch_sampler=batch_sampler,
        num_workers=num_workers,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
def build_attr_test_loader(cfg, dataset_name):
    """Build the evaluation DataLoader for an attribute-recognition dataset.

    The dataset class is looked up in ``DATASET_REGISTRY`` by name and
    constructed with ``combineall=cfg.DATASETS.COMBINEALL``. Its ``test`` split
    is wrapped in an ``AttrDataset`` together with ``dataset.attr_dict``.

    Args:
        cfg: config node. ``cfg.TEST.IMS_PER_BATCH`` is the global test batch
            size; it is floor-divided by the number of processes.
        dataset_name: registry key of the dataset class.

    Returns:
        A ``DataLoader`` whose batches are assembled by ``fast_batch_collator``.

    Raises:
        ValueError: if ``cfg.TEST.IMS_PER_BATCH`` is smaller than the number
            of processes, which would give every process a batch size of 0.
    """
    # Private, unfrozen copy; the caller's config is left untouched.
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(
        root=_root, combineall=cfg.DATASETS.COMBINEALL)
    # Print the split summary once, not once per process.
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, dataset.attr_dict, test_transforms)

    world_size = comm.get_world_size()
    mini_batch_size = cfg.TEST.IMS_PER_BATCH // world_size
    if mini_batch_size < 1:
        # BatchSampler would reject 0 anyway, but with a message that names
        # neither the config key nor the world size.
        raise ValueError(
            "cfg.TEST.IMS_PER_BATCH ({}) must be >= the number of processes ({}); "
            "otherwise each process gets a batch size of 0".format(
                cfg.TEST.IMS_PER_BATCH, world_size))

    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: the final, possibly smaller, batch is kept.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)

    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
def build_face_test_loader(cfg, dataset_name, **kwargs):
    """Build the evaluation DataLoader for a face-verification dataset.

    The dataset class is looked up in ``DATASET_REGISTRY`` by name. Its
    ``carray`` and ``is_same`` attributes are wrapped in a ``FaceCommDataset``.

    Args:
        cfg: config node. ``cfg.TEST.IMS_PER_BATCH`` is the global test batch
            size; it is floor-divided by the number of processes.
        dataset_name: registry key of the dataset class.
        **kwargs: forwarded unchanged to the dataset constructor.

    Returns:
        A tuple ``(test_loader, test_set.labels)``. NOTE(review): ``labels``
        presumably derives from ``is_same`` -- confirm in FaceCommDataset.

    Raises:
        ValueError: if ``cfg.TEST.IMS_PER_BATCH`` is smaller than the number
            of processes, which would give every process a batch size of 0.
    """
    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    # Print the split summary once, not once per process.
    if comm.is_main_process():
        dataset.show_test()

    test_set = FaceCommDataset(dataset.carray, dataset.is_same)

    world_size = comm.get_world_size()
    mini_batch_size = cfg.TEST.IMS_PER_BATCH // world_size
    if mini_batch_size < 1:
        # BatchSampler would reject 0 anyway, but with a message that names
        # neither the config key nor the world size.
        raise ValueError(
            "cfg.TEST.IMS_PER_BATCH ({}) must be >= the number of processes ({}); "
            "otherwise each process gets a batch size of 0".format(
                cfg.TEST.IMS_PER_BATCH, world_size))

    data_sampler = samplers.InferenceSampler(len(test_set))
    # drop_last=False: the final, possibly smaller, batch is kept.
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)

    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader, test_set.labels