def build_attr_train_loader(cfg):
    train_items = list()
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, train_transforms, attr_dict)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    # IMS_PER_BATCH is the global batch size; split it evenly across processes.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.TrainingSampler(len(train_set))
    # drop_last=True keeps every batch at exactly mini_batch_size images.
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
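# Hedged sketch of the per-process batch arithmetic above (values are
# illustrative, not from the original source): IMS_PER_BATCH is the global
# batch size, split evenly across distributed processes.
global_batch = 256                            # cfg.SOLVER.IMS_PER_BATCH
world_size = 4                                # comm.get_world_size()
mini_batch_size = global_batch // world_size  # 64 images per process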
def build_reid_train_loader(cfg):
    train_transforms = build_transforms(cfg, is_train=True)

    logger = logging.getLogger(__name__)
    train_items = list()
    for d in cfg.DATASETS.NAMES:
        logger.info('prepare training set {}'.format(d))
        dataset = DATASET_REGISTRY.get(d)(cfg)
        train_items.extend(dataset.train)

    train_set = BlackreidDataset(train_items, train_transforms, mode='train', relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_size = cfg.SOLVER.IMS_PER_BATCH
    num_instance = cfg.DATALOADER.NUM_INSTANCE

    if cfg.DATALOADER.PK_SAMPLER:
        # P x K sampling: each batch holds batch_size // num_instance identities,
        # with num_instance images per identity.
        data_sampler = samplers.RandomIdentitySampler(train_set.img_items, batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
    )
    return data_prefetcher(cfg, train_loader)
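# Hedged usage sketch (config values below are assumptions): with PK sampling,
# each batch is P identities x K instances, so IMS_PER_BATCH should be a
# multiple of NUM_INSTANCE.
cfg.DATALOADER.PK_SAMPLER = True
cfg.DATALOADER.NUM_INSTANCE = 4   # K = 4 instances per identity
cfg.SOLVER.IMS_PER_BATCH = 64     # P = 64 / 4 = 16 identities per batch
train_loader = build_reid_train_loader(cfg)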
def build_attr_test_loader(cfg, dataset_name):
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    # Argument order fixed to match AttrDataset(items, transforms, attr_dict)
    # as used by the train loader above.
    test_set = AttrDataset(test_items, test_transforms, dataset.attr_dict)

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
def build_train_loader(cls, cfg):
    """
    Returns:
        iterable

    It now calls :func:`fastreid.data.build_reid_train_loader`.
    Overwrite it if you'd like a different data loader.
    """
    logger = logging.getLogger("fastreid.clas_dataset")
    logger.info("Prepare training set")

    train_items = list()
    for d in cfg.DATASETS.NAMES:
        data = DATASET_REGISTRY.get(d)(root=_root)
        if comm.is_main_process():
            data.show_train()
        train_items.extend(data.train)

    transforms = build_transforms(cfg, is_train=True)
    train_set = ClasDataset(train_items, transforms)
    data_loader = build_reid_train_loader(cfg, train_set=train_set)

    # Save index-to-class dictionary
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        path = os.path.join(output_dir, "idx2class.json")
        with PathManager.open(path, "w") as f:
            json.dump(train_set.idx_to_class, f)

    return data_loader
def build_train_loader(cls, cfg):
    path_imgrec = cfg.DATASETS.REC_PATH
    if path_imgrec != "":
        transforms = build_transforms(cfg, is_train=True)
        train_set = MXFaceDataset(path_imgrec, transforms)
        return build_reid_train_loader(cfg, train_set=train_set)
    else:
        return DefaultTrainer.build_train_loader(cfg)
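# Hedged config sketch (the path below is illustrative, not from the source):
# a non-empty REC_PATH selects the MXNet RecordIO branch, while an empty
# string falls back to DefaultTrainer's image-based pipeline.
cfg.DATASETS.REC_PATH = "/data/faces_emore/train.rec"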
def build_test_loader(cls, cfg, dataset_name):
    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root)
    attr_dict = dataset.attr_dict
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, test_transforms, attr_dict)
    data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
    return data_loader
def build_test_loader(cls, cfg, dataset_name):
    """
    Returns:
        iterable

    It now calls :func:`fastreid.data.build_reid_test_loader`.
    Overwrite it if you'd like a different data loader.
    """
    data = DATASET_REGISTRY.get(dataset_name)(root=_root)
    if comm.is_main_process():
        data.show_test()
    transforms = build_transforms(cfg, is_train=False)
    test_set = ClasDataset(data.query, transforms)
    data_loader, _ = build_reid_test_loader(cfg, test_set=test_set)
    return data_loader
def build_dataset(cls, cfg, img_items, is_train=False, relabel=False, transforms=None, with_mem_idx=False):
    if transforms is None:
        transforms = build_transforms(cfg, is_train=is_train)
    if with_mem_idx:
        # Append each item's position in the sorted list as a trailing
        # memory index field.
        sorted_img_items = sorted(img_items)
        for i in range(len(sorted_img_items)):
            sorted_img_items[i] += (i,)
        return InMemoryDataset(sorted_img_items, transforms, relabel)
    else:
        return CommDataset(img_items, transforms, relabel)
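# Hedged illustration of the with_mem_idx transformation (the item layout is
# assumed to be the usual (img_path, pid, camid) tuple): each sorted item
# gains its position as a trailing memory index.
items = [("b.jpg", 1, 0), ("a.jpg", 0, 0)]
items = sorted(items)                              # [("a.jpg", 0, 0), ("b.jpg", 1, 0)]
items = [it + (i,) for i, it in enumerate(items)]  # [("a.jpg", 0, 0, 0), ("b.jpg", 1, 0, 1)]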
def build_train_loader(cls, cfg):
    logger = logging.getLogger("fastreid.attr_dataset")

    train_items = list()
    attr_dict = None
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, train_transforms, attr_dict)

    data_loader = build_reid_train_loader(cfg, train_set=train_set)
    AttrTrainer.sample_weights = data_loader.dataset.sample_weights
    return data_loader
def build_reid_test_loader(cfg, dataset_name):
    test_transforms = build_transforms(cfg, is_train=False)

    logger = logging.getLogger(__name__)
    logger.info('prepare test set {}'.format(dataset_name))
    dataset = DATASET_REGISTRY.get(dataset_name)(cfg)
    # Query and gallery are concatenated; len(dataset.query) marks the split.
    test_items = dataset.query + dataset.gallery
    test_set = BlackreidDataset(test_items, test_transforms, mode='test', relabel=False)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_size = cfg.TEST.IMS_PER_BATCH
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=num_workers,
        collate_fn=fast_batch_collator,
    )
    return data_prefetcher(cfg, test_loader), len(dataset.query)
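# Hedged usage sketch (the dataset name is an assumption): the loader yields
# query followed by gallery, and num_query marks the boundary, so a downstream
# evaluator can split extracted features.
test_loader, num_query = build_reid_test_loader(cfg, "Market1501")
# features[:num_query]  -> query embeddings
# features[num_query:]  -> gallery embeddings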
def build_train_loader(cls, cfg):
    """
    Returns:
        iterable

    It now calls :func:`fastreid.data.build_reid_train_loader`.
    Overwrite it if you'd like a different data loader.
    """
    logger = logging.getLogger("fastreid.clas_dataset")
    logger.info("Prepare training set")

    train_items = list()
    for d in cfg.DATASETS.NAMES:
        data = DATASET_REGISTRY.get(d)(root=_root)
        if comm.is_main_process():
            data.show_train()
        train_items.extend(data.train)

    transforms = build_transforms(cfg, is_train=True)
    train_set = ClasDataset(train_items, transforms)
    cls.idx2class = train_set.idx_to_class
    data_loader = build_reid_train_loader(cfg, train_set=train_set)
    return data_loader
def eval_ARI_purity(cls, cfg, model, transforms=None, **kwargs):
    num_devices = comm.get_world_size()
    logger = logging.getLogger('fastreid')
    _root = os.getenv("FASTREID_DATASETS", "/root/datasets")
    evaluator = ARI_Purity_Evaluator(cfg)
    results = OrderedDict()

    if transforms is None:
        transforms = build_transforms(cfg, is_train=False)

    total_datasets = set(cfg.DATASETS.TESTS)
    for dataset_name in total_datasets:
        logger.info(f"Starting evaluating ARI on dataset {dataset_name}")
        data = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
        test_items = data.train
        test_set = CommDataset(test_items, transforms, relabel=True)
        data_loader, num_query = build_reid_test_loader(cfg, dataset_name=dataset_name, test_set=test_set)
        total = len(data_loader)  # inference data loader must have a fixed length

        evaluator.reset()
        img_nums = len(test_items)
        num_warmup = min(5, total - 1)
        start_time = time.perf_counter()
        total_compute_time = 0

        with inference_context(model), torch.no_grad():
            for idx, inputs in enumerate(data_loader):
                if idx == num_warmup:
                    start_time = time.perf_counter()
                    total_compute_time = 0

                start_compute_time = time.perf_counter()
                outputs = model(inputs)
                # Flip test: average features of the original and the
                # horizontally flipped images.
                if cfg.TEST.FLIP.ENABLED:
                    inputs["images"] = inputs["images"].flip(dims=[3])
                    flip_outputs = model(inputs)
                    outputs = (outputs + flip_outputs) / 2

                if torch.cuda.is_available():
                    torch.cuda.synchronize()
                total_compute_time += time.perf_counter() - start_compute_time
                evaluator.process(inputs, outputs)

                iters_after_start = idx + 1 - num_warmup * int(idx >= num_warmup)
                seconds_per_batch = total_compute_time / iters_after_start
                if idx >= num_warmup * 2 or seconds_per_batch > 30:
                    total_seconds_per_img = (time.perf_counter() - start_time) / iters_after_start
                    eta = datetime.timedelta(seconds=int(total_seconds_per_img * (total - idx - 1)))
                    log_every_n_seconds(
                        logging.INFO,
                        "Inference done {}/{}. {:.4f} s / batch. ETA={}".format(
                            idx + 1, total, seconds_per_batch, str(eta)
                        ),
                        n=30,
                    )

        # Measure the time only for this worker (before the synchronization barrier)
        total_time = time.perf_counter() - start_time
        total_time_str = str(datetime.timedelta(seconds=total_time))
        # NOTE this format is parsed by grep
        logger.info(
            "Total inference time: {} ({:.6f} s / batch per device, on {} devices)".format(
                total_time_str, total_time / (total - num_warmup), num_devices
            )
        )
        total_compute_time_str = str(datetime.timedelta(seconds=int(total_compute_time)))
        logger.info(
            "Total inference pure compute time: {} ({:.6f} s / batch per device, on {} devices)".format(
                total_compute_time_str, total_compute_time / (total - num_warmup), num_devices
            )
        )

        results_i = evaluator.evaluate()
        ARI_score, purity = results_i
        results[f'{dataset_name}/ARI'] = ARI_score
        results[f'{dataset_name}/purity'] = purity

        if comm.is_main_process():
            assert isinstance(results, dict), (
                "Evaluator must return a dict on the main process. Got {} instead.".format(results)
            )
            logger.info(f"ARI score for {dataset_name} is {ARI_score:.4f}")
            logger.info(f"Purity for {dataset_name} is {purity:.4f}")

    return results
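# Hedged sketch of the two metrics the evaluator reports. The actual
# ARI_Purity_Evaluator implementation is not shown in this section, so this
# sklearn/numpy version is only an assumed reference implementation.
import numpy as np
from sklearn.metrics import adjusted_rand_score

def purity_score(labels_true, labels_pred):
    """Purity: map each predicted cluster to its majority true label and
    count the fraction of samples assigned correctly."""
    labels_true = np.asarray(labels_true)
    labels_pred = np.asarray(labels_pred)
    correct = 0
    for c in np.unique(labels_pred):
        members = labels_true[labels_pred == c]
        correct += np.bincount(members).max()
    return correct / len(labels_true)

# Both metrics are invariant to cluster-id permutation: a perfect clustering
# with swapped cluster labels still scores 1.0.
assert adjusted_rand_score([0, 0, 1, 1], [1, 1, 0, 0]) == 1.0
assert purity_score([0, 0, 1, 1], [1, 1, 0, 0]) == 1.0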