def __init__(self, data_source, mini_batch_size: int, num_instances: int = 4,
             seed: Optional[int] = None, with_mem_idx=False):
    self.data_source = data_source
    self.num_instances = num_instances
    self.num_pids_per_batch = mini_batch_size // self.num_instances
    self.with_mem_idx = with_mem_idx

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
    self.batch_size = mini_batch_size * self._world_size

    self.index_pid = defaultdict(int)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = sorted(list(self.pid_index.keys()))
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
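# --- Illustrative sketch (assumption: not part of the original file) ---
# Shows the PK-batch arithmetic used by the __init__ above: each GPU batch is
# `num_pids_per_batch` identities x `num_instances` images, and the global
# batch spans all workers. The helper name below is hypothetical.
def _pk_batch_shape(mini_batch_size: int, num_instances: int, world_size: int):
    num_pids_per_batch = mini_batch_size // num_instances  # identities per GPU batch
    batch_size = mini_batch_size * world_size              # global batch size
    return num_pids_per_batch, batch_size

# e.g. mini_batch_size=64, num_instances=4 on 2 GPUs
# -> 16 identities per GPU batch, global batch of 128 images
assert _pk_batch_shape(64, 4, 2) == (16, 128)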
def __init__(self, data_source: str, batch_size: int, num_instances: int,
             seed: Optional[int] = None):
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances

    self.index_pid = defaultdict(list)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        # camid = info[2]
        camid = info[3]['domains']  # use the domain label in place of the camera id
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
def __init__(self, data_source: str, batch_size: int, num_instances: int,
             delete_rem: bool, seed: Optional[int] = None, cfg=None):
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances
    self.delete_rem = delete_rem

    self.index_pid = defaultdict(list)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Print a short histogram of images-per-identity (first few bins and the last one).
    val_pid_index = [len(x) for x in self.pid_index.values()]
    min_v = min(val_pid_index)
    max_v = max(val_pid_index)
    hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v + 1)]
    num_print = 5
    for i, x in enumerate(range(min_v, min_v + min(len(hist_pid_index), num_print))):
        print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
    print('...')
    print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

    # Round each identity's image count to a multiple of num_instances:
    # pad up by default, or truncate down when delete_rem is set (an identity
    # with fewer than num_instances images is always padded up).
    val_pid_index_upper = []
    for x in val_pid_index:
        v_remain = x % self.num_instances
        if v_remain == 0:
            val_pid_index_upper.append(x)
        else:
            if self.delete_rem:
                if x < self.num_instances:
                    val_pid_index_upper.append(x - v_remain + self.num_instances)
                else:
                    val_pid_index_upper.append(x - v_remain)
            else:
                val_pid_index_upper.append(x - v_remain + self.num_instances)

    total_images = sum(val_pid_index_upper)
    total_images = total_images - (total_images % self.batch_size) - self.batch_size  # approximate epoch length
    self.total_images = total_images
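# --- Illustrative sketch (assumption: not part of the original file) ---
# The rounding rule above, extracted as a standalone helper so its behavior is
# easy to check; `_round_to_num_instances` is a hypothetical name.
def _round_to_num_instances(x: int, num_instances: int, delete_rem: bool) -> int:
    v_remain = x % num_instances
    if v_remain == 0:
        return x
    if delete_rem and x >= num_instances:
        return x - v_remain                   # truncate the remainder
    return x - v_remain + num_instances       # pad up to the next multiple

# With num_instances=4: 10 images pad to 12, or truncate to 8 with delete_rem;
# 3 images always pad to 4, even when delete_rem is set.
assert _round_to_num_instances(10, 4, False) == 12
assert _round_to_num_instances(10, 4, True) == 8
assert _round_to_num_instances(3, 4, True) == 4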
def __init__(self, data_source: str, mini_batch_size: int, num_instances: int,
             set_weight: list, seed: Optional[int] = None):
    self.data_source = data_source
    self.num_instances = num_instances
    self.num_pids_per_batch = mini_batch_size // self.num_instances
    self.set_weight = set_weight

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
    self.batch_size = mini_batch_size * self._world_size

    assert self.batch_size % (sum(self.set_weight) * self.num_instances) == 0 and \
        self.batch_size > sum(self.set_weight) * self.num_instances, \
        "Batch size must be divisible by (and larger than) sum(set_weight) * num_instances"

    self.index_pid = dict()
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)
    self.cam_pid = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)
        self.cam_pid[camid].append(pid)

    # Sampling probability for each pid within its camera, proportional to
    # the number of images that pid contributes to the camera.
    self.set_pid_prob = defaultdict(list)
    for camid, pid_list in self.cam_pid.items():
        index_per_pid = []
        for pid in pid_list:
            index_per_pid.append(len(self.pid_index[pid]))
        cam_image_number = sum(index_per_pid)
        prob = [i / cam_image_number for i in index_per_pid]
        self.set_pid_prob[camid] = prob

    self.pids = sorted(list(self.pid_index.keys()))
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
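# --- Illustrative sketch (assumption: not part of the original file) ---
# The per-camera sampling probabilities above, computed for a toy layout:
# within one camera, pid 'a' has 2 images and pid 'b' has 6.
index_per_pid = [2, 6]
cam_image_number = sum(index_per_pid)                    # 8 images in this camera
prob = [i / cam_image_number for i in index_per_pid]
assert prob == [0.25, 0.75]  # pids are drawn proportionally to their image count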
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same across all workers.
            If None, will use a random seed shared among workers
            (require synchronization among all workers).
    """
    self._size = size
    assert size > 0
    self._shuffle = shuffle

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
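# --- Illustrative sketch (assumption: not part of the original file) ---
# Samplers initialized like the __init__ above are typically paired with an
# infinite, rank-sharded index stream (the detectron2/fastreid pattern); a
# minimal sketch with hypothetical helper names:
from itertools import islice
import torch

def _infinite_indices(size, shuffle, seed):
    g = torch.Generator()
    g.manual_seed(seed)  # identical seed on every worker -> identical streams
    while True:
        if shuffle:
            yield from torch.randperm(size, generator=g).tolist()
        else:
            yield from range(size)

def _iter_for_rank(size, shuffle, seed, rank, world_size):
    # each rank takes every world_size-th index from the shared stream
    yield from islice(_infinite_indices(size, shuffle, seed), rank, None, world_size)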
def __init__(self, data_source: List, size: int = None, seed: Optional[int] = None,
             callback_get_label: Callable = None):
    self.data_source = data_source
    # consider all elements in the dataset
    self.indices = list(range(len(data_source)))
    # if `size` is not provided, draw `len(indices)` samples in each iteration
    self._size = len(self.indices) if size is None else size
    self.callback_get_label = callback_get_label

    # distribution of classes in the dataset
    label_to_count = {}
    for idx in self.indices:
        label = self._get_label(data_source, idx)
        label_to_count[label] = label_to_count.get(label, 0) + 1

    # weight each sample by the inverse frequency of its class
    weights = [1.0 / label_to_count[self._get_label(data_source, idx)]
               for idx in self.indices]
    self.weights = torch.DoubleTensor(weights)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
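# --- Illustrative sketch (assumption: not part of the original file) ---
# How inverse-frequency weights like `self.weights` are typically consumed:
# torch.multinomial draws indices so rare classes appear about as often as
# common ones.
import torch

weights = torch.DoubleTensor([1 / 3, 1 / 3, 1 / 3, 1.0])  # 3 images of class A, 1 of class B
g = torch.Generator()
g.manual_seed(0)
picks = torch.multinomial(weights, num_samples=8, replacement=True, generator=g)
# weights sum to 2, so index 3 (the rare class) is drawn with probability 0.5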
def __init__(self, data_source: str, batch_size: int, num_instances: int,
             delete_rem: bool, seed: Optional[int] = None, cfg=None):
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances
    self.delete_rem = delete_rem

    self.index_pid = defaultdict(list)
    self.pid_domain = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        domainid = info[3]['domains']
        if cfg.DATALOADER.CAMERA_TO_DOMAIN:
            # make pids unique across domains when cameras act as domains
            pid = info[1] + str(domainid)
        else:
            pid = info[1]
        self.index_pid[index] = pid
        # self.pid_domain[pid].append(domainid)
        self.pid_domain[pid] = domainid
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.domains = list(self.pid_domain.values())
    self.num_identities = len(self.pids)
    self.num_domains = len(set(self.domains))

    # split the global batch evenly across domains
    self.batch_size //= self.num_domains
    self.num_pids_per_batch //= self.num_domains

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Print a short histogram of images-per-identity (first few bins and the last one).
    val_pid_index = [len(x) for x in self.pid_index.values()]
    min_v = min(val_pid_index)
    max_v = max(val_pid_index)
    hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v + 1)]
    num_print = 5
    for i, x in enumerate(range(min_v, min_v + min(len(hist_pid_index), num_print))):
        print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
    print('...')
    print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

    # Round each identity's image count to a multiple of num_instances
    # (same rule as above: pad up, or truncate down when delete_rem is set).
    val_pid_index_upper = []
    for x in val_pid_index:
        v_remain = x % self.num_instances
        if v_remain == 0:
            val_pid_index_upper.append(x)
        else:
            if self.delete_rem:
                if x < self.num_instances:
                    val_pid_index_upper.append(x - v_remain + self.num_instances)
                else:
                    val_pid_index_upper.append(x - v_remain)
            else:
                val_pid_index_upper.append(x - v_remain + self.num_instances)

    # The epoch length is dictated by the largest domain.
    cnt_domains = [0 for _ in range(self.num_domains)]
    for val, index in zip(val_pid_index_upper, self.domains):
        cnt_domains[index] += val
    self.max_cnt_domains = max(cnt_domains)
    self.total_images = self.num_domains * (
        self.max_cnt_domains - (self.max_cnt_domains % self.batch_size) - self.batch_size)
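# --- Illustrative sketch (assumption: not part of the original file) ---
# The domain split above: a global batch is divided evenly across domains, so
# each domain contributes batch_size/num_domains images per step.
batch_size, num_instances, num_domains = 96, 4, 3
per_domain_batch = batch_size // num_domains                     # 32 images per domain
per_domain_pids = (batch_size // num_instances) // num_domains   # 8 identities per domain
assert per_domain_batch == 32 and per_domain_pids == 8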
def build_reid_train_loader(cfg):
    # build datasets
    cfg = cfg.clone()
    frozen = cfg.is_frozen()
    cfg.defrost()

    individual_flag_ori = cfg.DATALOADER.INDIVIDUAL
    individual_flag_meta = cfg.META.DATA.INDIVIDUAL
    if cfg.META.DATA.NAMES == "":
        individual_flag_meta = False

    gettrace = getattr(sys, 'gettrace', None)
    if gettrace is not None and gettrace():
        print('*' * 100)
        print('Hmm, Big Debugger is watching me')
        print('*' * 100)
        num_workers = 0
    else:
        num_workers = cfg.DATALOADER.NUM_WORKERS

    # transforms
    train_transforms = build_transforms(cfg, is_train=True, is_fake=False)
    if (cfg.META.DATA.NAMES != "") and \
            (cfg.META.DATA.LOADER_FLAG == 'synth' or cfg.META.DATA.SYNTH_FLAG != 'none'):
        synth_transforms = build_transforms(cfg, is_train=True, is_fake=True)
        cfg.META.DATA.LOADER_FLAG = 'each'
    else:
        synth_transforms = None

    train_set_all = []
    train_items = list()
    domain_idx = 0
    camera_all = list()

    # load datasets
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if len(dataset.train[0]) < 4:
            # append a domain label to each (img_path, pid, camid) tuple
            for i, x in enumerate(dataset.train):
                add_info = {}
                if cfg.DATALOADER.CAMERA_TO_DOMAIN:
                    add_info['domains'] = dataset.train[i][2]
                    camera_all.append(dataset.train[i][2])
                else:
                    add_info['domains'] = int(domain_idx)
                dataset.train[i] = list(dataset.train[i])
                dataset.train[i].append(add_info)
                dataset.train[i] = tuple(dataset.train[i])
            domain_idx += 1
        train_items.extend(dataset.train)
        if individual_flag_ori or individual_flag_meta:  # individual set
            train_set_all.append(dataset.train)

    if cfg.DATALOADER.CAMERA_TO_DOMAIN:  # used for single-source DG
        num_domains = len(set(camera_all))
    else:
        num_domains = domain_idx
    cfg.META.DATA.NUM_DOMAINS = num_domains

    if cfg.DATALOADER.NAIVE_WAY:
        logger.info('**[dataloader info: random domain shuffle]**')
    else:
        logger.info('**[dataloader info: uniform domain]**')
        logger.info('**[The batch size should be a multiple of the number of domains.]**')
        assert (cfg.SOLVER.IMS_PER_BATCH % (num_domains * cfg.DATALOADER.NUM_INSTANCE) == 0), \
            "cfg.SOLVER.IMS_PER_BATCH should be a multiple of (num_domain x num_instance)"
        assert (cfg.META.DATA.MTRAIN_MINI_BATCH % (num_domains * cfg.META.DATA.MTRAIN_NUM_INSTANCE) == 0), \
            "cfg.META.DATA.MTRAIN_MINI_BATCH should be a multiple of (num_domain x num_instance)"
        assert (cfg.META.DATA.MTEST_MINI_BATCH % (num_domains * cfg.META.DATA.MTEST_NUM_INSTANCE) == 0), \
            "cfg.META.DATA.MTEST_MINI_BATCH should be a multiple of (num_domain x num_instance)"

    if individual_flag_ori:
        cfg.SOLVER.IMS_PER_BATCH //= num_domains
    if individual_flag_meta:
        cfg.META.DATA.MTRAIN_MINI_BATCH //= num_domains
        cfg.META.DATA.MTEST_MINI_BATCH //= num_domains

    if 'keypoint' in cfg.META.DATA.NAMES:  # used for keypoint data (not used in MetaBIN)
        cfg, train_set_all = make_keypoint_data(cfg=cfg, data_name=cfg.META.DATA.NAMES,
                                                train_items=train_items)

    train_set = CommDataset(train_items, train_transforms, relabel=True)
    if (synth_transforms is not None) and (cfg.META.DATA.NAMES != ""):
        # used for synthetic data (not used in MetaBIN)
        synth_set = CommDataset(train_items, synth_transforms, relabel=True)

    if individual_flag_ori or individual_flag_meta:  # for individual dataloaders
        relabel_flag = False
        if individual_flag_meta:
            relabel_flag = cfg.META.DATA.RELABEL
        for i, x in enumerate(train_set_all):
            train_set_all[i] = CommDataset(x, train_transforms, relabel=relabel_flag)
            if not relabel_flag:
                train_set_all[i].relabel = True
                train_set_all[i].pid_dict = train_set.pid_dict

        # Sanity check: the individual sets must cover the full training set
        cnt_data = 0
        for x in train_set_all:
            cnt_data += len(x.img_items)
        if cnt_data != len(train_set.img_items):
            print("data loading error, check build.py")

    if individual_flag_ori:  # for individual dataloaders (domain-wise)
        train_loader = []
        if len(train_set_all) > 0:
            for i, x in enumerate(train_set_all):
                train_loader.append(
                    make_sampler(
                        train_set=x,
                        num_batch=cfg.SOLVER.IMS_PER_BATCH,
                        num_instance=cfg.DATALOADER.NUM_INSTANCE,
                        num_workers=num_workers,
                        mini_batch_size=cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(),
                        drop_last=cfg.DATALOADER.DROP_LAST,
                        flag1=cfg.DATALOADER.NAIVE_WAY,
                        flag2=cfg.DATALOADER.DELETE_REM,
                        cfg=cfg))
    else:
        train_loader = make_sampler(
            train_set=train_set,
            num_batch=cfg.SOLVER.IMS_PER_BATCH,
            num_instance=cfg.DATALOADER.NUM_INSTANCE,
            num_workers=num_workers,
            mini_batch_size=cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size(),
            drop_last=cfg.DATALOADER.DROP_LAST,
            flag1=cfg.DATALOADER.NAIVE_WAY,
            flag2=cfg.DATALOADER.DELETE_REM,
            cfg=cfg)

    train_loader_add = {}
    train_loader_add['mtrain'] = None  # meta-train dataset
    train_loader_add['mtest'] = None  # meta-test dataset

    if cfg.META.DATA.NAMES != "":
        if cfg.META.DATA.LOADER_FLAG == 'each':  # "each": meta-init / meta-train / meta-test
            make_mtrain = True
            make_mtest = True
        elif cfg.META.DATA.LOADER_FLAG == 'diff':  # "diff": meta-init / meta-final
            make_mtrain = True
            make_mtest = False
        elif cfg.META.DATA.LOADER_FLAG == 'same':  # "same": meta-init
            make_mtrain = False
            make_mtest = False
        else:
            print('error in cfg.META.DATA.LOADER_FLAG')

        train_loader_add['mtrain'] = [] if make_mtrain else None
        train_loader_add['mtest'] = [] if make_mtest else None

        if cfg.META.DATA.SYNTH_SAME_SEED:
            seed = comm.shared_random_seed()
        else:
            seed = None

        if individual_flag_meta:  # for individual datasets (domain-wise)
            for i, x in enumerate(train_set_all):
                if make_mtrain:
                    train_loader_add['mtrain'].append(
                        make_sampler(
                            train_set=x,
                            num_batch=cfg.META.DATA.MTRAIN_MINI_BATCH,
                            num_instance=cfg.META.DATA.MTRAIN_NUM_INSTANCE,
                            num_workers=num_workers,
                            mini_batch_size=cfg.META.DATA.MTRAIN_MINI_BATCH // comm.get_world_size(),
                            drop_last=cfg.META.DATA.DROP_LAST,
                            flag1=cfg.META.DATA.NAIVE_WAY,
                            flag2=cfg.META.DATA.DELETE_REM,
                            seed=seed,
                            cfg=cfg))
                if make_mtest:
                    train_loader_add['mtest'].append(
                        make_sampler(
                            train_set=x,
                            num_batch=cfg.META.DATA.MTEST_MINI_BATCH,
                            num_instance=cfg.META.DATA.MTEST_NUM_INSTANCE,
                            num_workers=num_workers,
                            mini_batch_size=cfg.META.DATA.MTEST_MINI_BATCH // comm.get_world_size(),
                            drop_last=cfg.META.DATA.DROP_LAST,
                            flag1=cfg.META.DATA.NAIVE_WAY,
                            flag2=cfg.META.DATA.DELETE_REM,
                            seed=seed,
                            cfg=cfg))
        else:
            if make_mtrain:  # meta-train dataset
                train_loader_add['mtrain'] = make_sampler(
                    train_set=train_set,
                    num_batch=cfg.META.DATA.MTRAIN_MINI_BATCH,
                    num_instance=cfg.META.DATA.MTRAIN_NUM_INSTANCE,
                    num_workers=num_workers,
                    mini_batch_size=cfg.META.DATA.MTRAIN_MINI_BATCH // comm.get_world_size(),
                    drop_last=cfg.META.DATA.DROP_LAST,
                    flag1=cfg.META.DATA.NAIVE_WAY,
                    flag2=cfg.META.DATA.DELETE_REM,
                    seed=seed,
                    cfg=cfg)
            if make_mtest:  # meta-test dataset
                if synth_transforms is None:
                    train_loader_add['mtest'] = make_sampler(
                        train_set=train_set,
                        num_batch=cfg.META.DATA.MTEST_MINI_BATCH,
                        num_instance=cfg.META.DATA.MTEST_NUM_INSTANCE,
                        num_workers=num_workers,
                        mini_batch_size=cfg.META.DATA.MTEST_MINI_BATCH // comm.get_world_size(),
                        drop_last=cfg.META.DATA.DROP_LAST,
                        flag1=cfg.META.DATA.NAIVE_WAY,
                        flag2=cfg.META.DATA.DELETE_REM,
                        seed=seed,
                        cfg=cfg)
                else:
                    train_loader_add['mtest'] = make_sampler(
                        train_set=synth_set,
                        num_batch=cfg.META.DATA.MTEST_MINI_BATCH,
                        num_instance=cfg.META.DATA.MTEST_NUM_INSTANCE,
                        num_workers=num_workers,
                        mini_batch_size=cfg.META.DATA.MTEST_MINI_BATCH // comm.get_world_size(),
                        drop_last=cfg.META.DATA.DROP_LAST,
                        flag1=cfg.META.DATA.NAIVE_WAY,
                        flag2=cfg.META.DATA.DELETE_REM,
                        seed=seed,
                        cfg=cfg)

    if frozen:
        cfg.freeze()
    return train_loader, train_loader_add, cfg
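# --- Usage sketch (assumption: not part of the original file) ---
# How the builder is typically called from a trainer; `get_cfg` and the config
# path below are hypothetical placeholders.
# cfg = get_cfg()
# cfg.merge_from_file('configs/some_experiment.yml')
# train_loader, train_loader_add, cfg = build_reid_train_loader(cfg)
# batch = next(iter(train_loader))             # when a single loader is returned
# mtrain_loader = train_loader_add['mtrain']   # None unless cfg.META.DATA.NAMES is set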