def build_attr_train_loader(cfg):
    """Build the training data loader for attribute recognition.

    Concatenates the train splits of every dataset named in
    ``cfg.DATASETS.NAMES`` (all of which must share a single attribute
    vocabulary) and wraps the result in a DataLoader driven by an
    infinite ``TrainingSampler``.

    Args:
        cfg: fastreid config node.

    Returns:
        torch.utils.data.DataLoader over an ``AttrDataset``.
    """
    attr_dict = None
    train_items = []
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(
            root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        # Every dataset must agree on the attribute dictionary; the first
        # one seen becomes the reference.
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    train_transforms = build_transforms(cfg, is_train=True)
    train_set = AttrDataset(train_items, train_transforms, attr_dict)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    # Per-process batch size under distributed training.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.TrainingSampler(len(train_set))
    # drop_last=True keeps every batch at exactly mini_batch_size.
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, True)

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
def build_train_loader(cls, cfg, train_set=None, sampler=None, with_mem_idx=False):
    """Build an iteration-based training data loader.

    Wraps ``train_set`` in an ``IterLoader`` that yields
    ``cfg.SOLVER.ITERS`` mini-batches per "epoch".

    Args:
        cls: class hook (unused in the body; kept for signature
            compatibility with the classmethod-style call site).
        cfg: fastreid config node.
        train_set: dataset exposing ``img_items``; required when
            ``sampler`` is None so one can be constructed from it.
        sampler: optional pre-built sampler. When None, one is created
            according to ``cfg.DATALOADER.SAMPLER_TRAIN``.
        with_mem_idx: forwarded to ``Preprocessor``; presumably makes it
            also emit the sample's memory-bank index — TODO confirm.

    Returns:
        IterLoader of length ``cfg.SOLVER.ITERS``.

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unknown
            sampler.
    """
    logger = logging.getLogger('fastreid')
    logger.info("Prepare training loader")

    total_batch_size = cfg.SOLVER.IMS_PER_BATCH
    # Per-process batch size under distributed training.
    mini_batch_size = total_batch_size // comm.get_world_size()

    if sampler is None:
        num_instance = cfg.DATALOADER.NUM_INSTANCE
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        # Lazy %-args: the message is only formatted if INFO is enabled.
        logger.info("Using training sampler %s", sampler_name)
        if sampler_name == "TrainingSampler":
            sampler = samplers.TrainingSampler(len(train_set))
        elif sampler_name == "NaiveIdentitySampler":
            sampler = samplers.NaiveIdentitySampler(
                train_set.img_items, mini_batch_size, num_instance)
        elif sampler_name == "BalancedIdentitySampler":
            sampler = samplers.BalancedIdentitySampler(
                train_set.img_items, mini_batch_size, num_instance)
        elif sampler_name == "SetReWeightSampler":
            set_weight = cfg.DATALOADER.SET_WEIGHT
            sampler = samplers.SetReWeightSampler(
                train_set.img_items, mini_batch_size, num_instance,
                set_weight)
        elif sampler_name == "ImbalancedDatasetSampler":
            sampler = samplers.ImbalancedDatasetSampler(
                train_set.img_items)
        else:
            raise ValueError(
                "Unknown training sampler: {}".format(sampler_name))

    iters = cfg.SOLVER.ITERS
    num_workers = cfg.DATALOADER.NUM_WORKERS
    # drop_last=True keeps every batch at exactly mini_batch_size.
    batch_sampler = BatchSampler(sampler, mini_batch_size, True)

    train_loader = IterLoader(
        DataLoader(
            Preprocessor(train_set, with_mem_idx),
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            pin_memory=True,
        ),
        length=iters,
    )
    return train_loader
def build_cls_train_loader(cfg, mapper=None, **kwargs):
    """Build the training data loader for classification.

    Gathers the train splits of all datasets in ``cfg.DATASETS.NAMES``
    into one ``CommDataset`` (without relabeling) and selects a sampler
    per ``cfg.DATALOADER`` settings.

    Args:
        cfg: fastreid config node; cloned so the caller's copy is
            untouched.
        mapper: optional transform callable; when None, transforms are
            built from the config.
        **kwargs: forwarded to each dataset constructor.

    Returns:
        torch.utils.data.DataLoader over a ``CommDataset``.
    """
    cfg = cfg.clone()

    train_items = []
    for dataset_name in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    # A caller-supplied mapper overrides the config-driven transforms.
    transforms = mapper if mapper is not None else build_transforms(
        cfg, is_train=True)

    train_set = CommDataset(train_items, transforms, relabel=False)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    # Per-process batch size under distributed training.
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    if not cfg.DATALOADER.PK_SAMPLER:
        data_sampler = samplers.TrainingSampler(len(train_set))
    elif cfg.DATALOADER.NAIVE_WAY:
        data_sampler = samplers.NaiveIdentitySampler(
            train_set.img_items, mini_batch_size, num_instance)
    else:
        data_sampler = samplers.BalancedIdentitySampler(
            train_set.img_items, mini_batch_size, num_instance)

    # drop_last=True keeps every batch at exactly mini_batch_size.
    batch_sampler = torch.utils.data.sampler.BatchSampler(
        data_sampler, mini_batch_size, True)

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )