def __init__(self, data_source, mini_batch_size: int, num_instances: int = 4, seed: Optional[int] = None, with_mem_idx=False):
    """Identity-balanced sampler: groups dataset indices by person id.

    Args:
        data_source: iterable of (img_path, pid, camid, ...) items.
        mini_batch_size: batch size per GPU.
        num_instances: images sampled per identity in a batch.
        seed: RNG seed shared across workers; derived from comm if None.
        with_mem_idx: stored flag (consumed by the iterator elsewhere).
    """
    self.data_source = data_source
    self.num_instances = num_instances
    self.num_pids_per_batch = mini_batch_size // self.num_instances
    self.with_mem_idx = with_mem_idx

    # fix: rank/world size were assigned twice (again at the very end of
    # the original __init__); assign once up front.
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
    self.batch_size = mini_batch_size * self._world_size

    self.index_pid = defaultdict(int)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = sorted(list(self.pid_index.keys()))
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
def __init__(self, data_source: str, mini_batch_size: int, num_instances: int, set_weight: list, seed: Optional[int] = None):
    """Camera-set re-weighting sampler.

    Args:
        data_source: iterable of (img_path, pid, camid, ...) items.
        mini_batch_size: batch size per GPU.
        num_instances: images sampled per identity.
        set_weight: per-set sampling weights; sum(set_weight) * num_instances
            must divide (and be strictly smaller than) the global batch size.
        seed: RNG seed shared across workers; derived from comm if None.
    """
    self.data_source = data_source
    self.num_instances = num_instances
    self.num_pids_per_batch = mini_batch_size // self.num_instances
    self.set_weight = set_weight

    # fix: rank/world size were assigned twice (again at the very end of
    # the original __init__); assign once up front.
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
    self.batch_size = mini_batch_size * self._world_size

    assert self.batch_size % (sum(self.set_weight) * self.num_instances) == 0 and \
        self.batch_size > sum(self.set_weight) * self.num_instances, \
        "Batch size must be divisible by the sum set weight"

    self.index_pid = dict()
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)
    self.cam_pid = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)
        self.cam_pid[camid].append(pid)

    # Get sampler prob for each cam: within a camera, an identity's
    # probability is proportional to its number of images.
    self.set_pid_prob = defaultdict(list)
    for camid, pid_list in self.cam_pid.items():
        index_per_pid = [len(self.pid_index[pid]) for pid in pid_list]
        cam_image_number = sum(index_per_pid)
        self.set_pid_prob[camid] = [i / cam_image_number for i in index_per_pid]

    self.pids = sorted(list(self.pid_index.keys()))
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
def build_cls_test_loader(cfg, dataset_name, mapper=None, **kwargs):
    """Build a distributed test dataloader for a classification dataset.

    Args:
        cfg (CfgNode): config; only TEST.IMS_PER_BATCH is read here.
        dataset_name: registered dataset name.
        mapper: optional transform; built from cfg when None.
        **kwargs: forwarded to the dataset constructor.
    """
    cfg = cfg.clone()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    if comm.is_main_process():
        dataset.show_test()

    transforms = mapper if mapper is not None else build_transforms(cfg, is_train=False)
    test_set = CommDataset(dataset.query, transforms, relabel=False)

    batch_per_gpu = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    inf_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inf_sampler, batch_per_gpu, False)
    return DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
def build_reid_test_loader(test_set, test_batch_size, num_query, num_workers=4):
    """Build a test-time reid dataloader.

    Workers are coordinated through an InferenceSampler so that together they
    produce the exact set of all samples. This interface is experimental.

    Args:
        test_set: dataset with test-time transforms applied.
        test_batch_size: global batch size across all workers.
        num_query: number of query samples, returned unchanged.
        num_workers: dataloader worker processes.

    Returns:
        (DataLoader, int): the loader and num_query.
    """
    batch_per_worker = test_batch_size // comm.get_world_size()
    inf_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inf_sampler, batch_per_worker, False)
    loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=num_workers,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return loader, num_query
def build_attr_train_loader(cfg):
    """Build the attribute-recognition training dataloader.

    Merges all datasets in cfg.DATASETS.NAMES; every dataset must expose the
    same attr_dict, otherwise an assertion fires.
    """
    attr_dict = None
    train_items = []
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        if attr_dict is not None:
            assert attr_dict == dataset.attr_dict, f"attr_dict in {d} does not match with previous ones"
        else:
            attr_dict = dataset.attr_dict
        train_items.extend(dataset.train)

    train_set = AttrDataset(train_items, build_transforms(cfg, is_train=True), attr_dict)

    batch_per_gpu = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.TrainingSampler(len(train_set))
    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, batch_per_gpu, True)

    return torch.utils.data.DataLoader(
        train_set,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
def build_reid_train_loader(train_set, *, sampler=None, total_batch_size, num_workers=0):
    """Build a dataloader for object re-identification with some default
    features. This interface is experimental.

    Args:
        train_set: the training dataset.
        sampler: index sampler driving the batches.
        total_batch_size: global batch size; divided across workers.
        num_workers: dataloader worker processes.

    Returns:
        torch.utils.data.DataLoader: a dataloader.
    """
    batch_per_worker = total_batch_size // comm.get_world_size()
    batch_sampler = torch.utils.data.sampler.BatchSampler(sampler, batch_per_worker, True)
    return torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
def build_reid_test_loader(cfg, dataset_name):
    """Build a reid test dataloader (query + gallery) from config.

    Returns:
        (DataLoader, int): the loader and the number of query samples.
    """
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, dataset_name=cfg.SPECIFIC_DATASET)
    if comm.is_main_process():
        dataset.show_test()

    # query first, then gallery — the evaluator splits on len(dataset.query)
    items = dataset.query + dataset.gallery
    test_set = CommDataset(items, build_transforms(cfg, is_train=False), relabel=False)

    batch_per_gpu = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    inf_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inf_sampler, batch_per_gpu, False)
    loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=0,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return loader, len(dataset.query)
def build_attr_test_loader(cfg, dataset_name):
    """Build the attribute-recognition test dataloader.

    Fix: AttrDataset was constructed as (items, attr_dict, transforms) while
    the train loader constructs it as (items, transforms, attr_dict); use the
    same, consistent argument order here.
    """
    cfg = cfg.clone()
    cfg.defrost()

    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
    if comm.is_main_process():
        dataset.show_test()
    test_items = dataset.test

    test_transforms = build_transforms(cfg, is_train=False)
    test_set = AttrDataset(test_items, test_transforms, dataset.attr_dict)

    mini_batch_size = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    data_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(data_sampler, mini_batch_size, False)
    test_loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return test_loader
def evaluate(self):
    """Cluster the gathered features and return clustering purity w.r.t. the
    ground-truth person ids (non-main workers return {}).
    """
    if comm.get_world_size() > 1:
        comm.synchronize()
        predictions = comm.gather(self._predictions, dst=0)
        predictions = list(itertools.chain(*predictions))
        # only the main process computes the metric
        if not comm.is_main_process():
            return {}
    else:
        predictions = self._predictions

    features = []
    pids = []
    # camids = []
    for prediction in predictions:
        features.append(prediction['feats'])
        pids.append(prediction['pids'])
        # camids.append(prediction['camids'])

    features = torch.cat(features, dim=0)
    pids = torch.cat(pids, dim=0).numpy()

    # pairwise Jaccard (re-ranking) distance used as the clustering input
    rerank_dist = compute_jaccard_distance(
        features,
        k1=self.cfg.CLUSTER.JACCARD.K1,
        k2=self.cfg.CLUSTER.JACCARD.K2,
    )
    pseudo_labels = self.cluster.fit_predict(rerank_dist)

    contingency_matrix = metrics.cluster.contingency_matrix(pids, pseudo_labels)
    # purity: fraction of samples assigned to their cluster's majority true id
    purity = np.sum(np.amax(contingency_matrix, axis=0)) / np.sum(contingency_matrix)
    return purity
def evaluate(self):
    """Gather logits/labels from every worker and report top-1 accuracy
    (non-main workers return {})."""
    if comm.get_world_size() > 1:
        comm.synchronize()
        gathered = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return {}
        predictions = list(itertools.chain(*gathered))
    else:
        predictions = self._predictions

    all_logits = torch.cat([p['logits'] for p in predictions], dim=0)
    all_labels = torch.cat([p['labels'] for p in predictions], dim=0)

    # measure accuracy and record loss
    acc1, = accuracy(all_logits, all_labels, topk=(1, ))

    self._results = OrderedDict()
    self._results["Acc@1"] = acc1
    self._results["metric"] = acc1
    return copy.deepcopy(self._results)
def __init__(self, data_source: str, batch_size: int, num_instances: int, seed: Optional[int] = None):
    """Identity sampler keyed on the per-item 'domains' label.

    Args:
        data_source: iterable of items where info[1] is the pid and
            info[3]['domains'] is the domain label (used in place of camid).
        batch_size: global batch size.
        num_instances: images sampled per identity.
        seed: shared RNG seed; derived from comm if None.
    """
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances

    self.index_pid = defaultdict(list)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for idx, info in enumerate(data_source):
        person = info[1]
        domain = info[3]['domains']  # domain label stands in for camid
        self.index_pid[idx] = person
        self.pid_cam[person].append(domain)
        self.pid_index[person].append(idx)

    self.pids = list(self.pid_index.keys())
    self.num_identities = len(self.pids)

    seed = comm.shared_random_seed() if seed is None else seed
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
def evaluate(self):
    """Cluster the gathered features and return the Adjusted Rand Index
    against the ground-truth person ids (non-main workers return {})."""
    if comm.get_world_size() > 1:
        comm.synchronize()
        gathered = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return {}
        predictions = list(itertools.chain(*gathered))
    else:
        predictions = self._predictions

    feats = torch.cat([p['feats'] for p in predictions], dim=0)
    pid_arr = torch.cat([p['pids'] for p in predictions], dim=0).numpy()

    # Jaccard (re-ranking) distance drives the clustering
    rerank_dist = compute_jaccard_distance(
        feats,
        k1=self.cfg.CLUSTER.JACCARD.K1,
        k2=self.cfg.CLUSTER.JACCARD.K2,
    )
    pseudo_labels = self.cluster.fit_predict(rerank_dist)
    return metrics.adjusted_rand_score(pid_arr, pseudo_labels)
def evaluate(self):
    """Gather per-rank logits/labels and report top-1 accuracy
    (non-main workers return {})."""
    if comm.get_world_size() > 1:
        comm.synchronize()
        pred_logits = comm.gather(self.pred_logits)
        pred_logits = sum(pred_logits, [])  # flatten the list-of-lists from all ranks
        labels = comm.gather(self.labels)
        labels = sum(labels, [])
        # fmt: off
        if not comm.is_main_process(): return {}
        # fmt: on
    else:
        pred_logits = self.pred_logits
        labels = self.labels

    pred_logits = torch.cat(pred_logits, dim=0)
    # NOTE(review): labels are stacked, not concatenated — assumes each entry
    # is a per-sample tensor; confirm against how self.labels is populated.
    labels = torch.stack(labels)

    # measure accuracy and record loss
    acc1, = accuracy(pred_logits, labels, topk=(1, ))

    self._results = OrderedDict()
    self._results["Acc@1"] = acc1
    self._results["metric"] = acc1
    return copy.deepcopy(self._results)
def __call__(self, embedding, targets):
    """Pairwise circle-style loss on cosine similarities.

    Args:
        embedding: (N, D) features, L2-normalized below.
        targets: (N,) integer identity labels.

    Returns:
        Scalar loss, scaled by self._scale.
    """
    embedding = F.normalize(embedding, dim=1)

    if comm.get_world_size() > 1:
        # gather across ranks so pairs are formed against the global batch
        all_embedding = concat_all_gather(embedding)
        all_targets = concat_all_gather(targets)
    else:
        all_embedding = embedding
        all_targets = targets

    # cosine similarity between local batch and global batch: (N, M)
    dist_mat = torch.matmul(embedding, all_embedding.t())
    N, M = dist_mat.size()
    is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t())
    is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())

    # NOTE(review): the diagonal (sample vs itself) is NOT removed from
    # is_pos here, unlike a sibling implementation that masks it — confirm
    # this is intentional.
    s_p = dist_mat[is_pos].contiguous().view(N, -1)
    s_n = dist_mat[is_neg].contiguous().view(N, -1)

    # detached adaptive weights and shifted margins (circle-loss form)
    alpha_p = F.relu(-s_p.detach() + 1 + self.m)
    alpha_n = F.relu(s_n.detach() + self.m)
    delta_p = 1 - self.m
    delta_n = self.m

    logit_p = -self.s * alpha_p * (s_p - delta_p)
    logit_n = self.s * alpha_n * (s_n - delta_n)

    loss = F.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

    return loss * self._scale
def evaluate(self):
    """Compute face-verification accuracy and best threshold from the gathered
    features, and save a ROC-curve image into the output directory.
    """
    if comm.get_world_size() > 1:
        comm.synchronize()
        features = comm.gather(self.features)
        features = sum(features, [])  # flatten per-rank lists
        # fmt: off
        if not comm.is_main_process(): return {}
        # fmt: on
    else:
        features = self.features

    features = torch.cat(features, dim=0)
    features = F.normalize(features, p=2, dim=1).numpy()

    self._results = OrderedDict()
    tpr, fpr, accuracy, best_thresholds = evaluate(features, self.labels)
    self._results["Accuracy"] = accuracy.mean() * 100
    self._results["Threshold"] = best_thresholds.mean()
    self._results["metric"] = accuracy.mean() * 100

    # render and persist the ROC curve for inspection
    buf = gen_plot(fpr, tpr)
    roc_curve = Image.open(buf)
    PathManager.mkdirs(self._output_dir)
    roc_curve.save(os.path.join(self._output_dir, self.dataset_name + "_roc.png"))
    return copy.deepcopy(self._results)
def build_test_loader(cls, cfg, test_set):
    """Build the test dataloader around a DataLoaderX pinned to the local rank.

    Args:
        cfg (CfgNode): reads TEST.IMS_PER_BATCH and DATALOADER.NUM_WORKERS.
        test_set: the test dataset.
    """
    logger = logging.getLogger('fastreid')
    logger.info("Prepare testing loader")

    batch_per_gpu = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    inf_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = BatchSampler(inf_sampler, batch_per_gpu, False)
    return DataLoaderX(
        comm.get_local_rank(),
        dataset=test_set,
        batch_sampler=batch_sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
def build_train_loader(cls, cfg, train_set=None, sampler=None, with_mem_idx=False):
    """Build the training dataloader: pick a sampler from cfg (unless one is
    supplied) and wrap everything in an iteration-bounded IterLoader.

    Args:
        cfg (CfgNode): reads SOLVER.IMS_PER_BATCH/ITERS and DATALOADER options.
        train_set: dataset exposing `img_items`; required when sampler is None.
        sampler: optional pre-built sampler that overrides cfg's choice.
        with_mem_idx: forwarded to Preprocessor — presumably attaches a
            memory-bank index per sample; confirm in Preprocessor.
    """
    logger = logging.getLogger('fastreid')
    logger.info("Prepare training loader")

    total_batch_size = cfg.SOLVER.IMS_PER_BATCH
    mini_batch_size = total_batch_size // comm.get_world_size()

    if sampler is None:
        num_instance = cfg.DATALOADER.NUM_INSTANCE
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = samplers.TrainingSampler(len(train_set))
        elif sampler_name == "NaiveIdentitySampler":
            sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
        elif sampler_name == "BalancedIdentitySampler":
            sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
        elif sampler_name == "SetReWeightSampler":
            set_weight = cfg.DATALOADER.SET_WEIGHT
            sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
        elif sampler_name == "ImbalancedDatasetSampler":
            sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    iters = cfg.SOLVER.ITERS
    num_workers = cfg.DATALOADER.NUM_WORKERS
    batch_sampler = BatchSampler(sampler, mini_batch_size, True)

    # IterLoader restarts the underlying loader and caps an epoch at `iters`
    train_loader = IterLoader(
        DataLoader(
            Preprocessor(train_set, with_mem_idx),
            num_workers=num_workers,
            batch_sampler=batch_sampler,
            pin_memory=True,
        ),
        length=iters,
    )
    # train_loader = DataLoaderX(
    #     comm.get_local_rank(),
    #     dataset=Preprocessor(train_set, with_mem_idx),
    #     num_workers=num_workers,
    #     batch_sampler=batch_sampler,
    #     collate_fn=fast_batch_collator,
    #     pin_memory=True,
    # )
    return train_loader
def __init__(self, cfg):
    """Build the full epoch-based training state: dataloader, model,
    optimizer, optional apex AMP, DDP wrapping, scheduler and checkpointer.

    Args:
        cfg (CfgNode):
    """
    super().__init__()
    logger = logging.getLogger("fastreid")
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
        setup_logger()

    # Assume these objects must be constructed in this order.
    data_loader = self.build_train_loader(cfg)
    cfg = self.auto_scale_hyperparams(cfg, data_loader.dataset.num_classes)
    model = self.build_model(cfg)
    optimizer = self.build_optimizer(cfg, model)

    optimizer_ckpt = dict(optimizer=optimizer)
    if cfg.SOLVER.FP16_ENABLED:
        # apex mixed precision; register amp state so it gets checkpointed too
        model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
        optimizer_ckpt.update(dict(amp=amp))

    # For training, wrap with DDP. But don't need this for inference.
    if comm.get_world_size() > 1:
        # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
        # for part of the parameters is not updated.
        # model = DistributedDataParallel(
        #     model, device_ids=[comm.get_local_rank()], broadcast_buffers=False
        # )
        model = DistributedDataParallel(model, delay_allreduce=True)

    self._trainer = (AMPTrainer if cfg.SOLVER.FP16_ENABLED else SimpleTrainer)(model, data_loader, optimizer)

    self.iters_per_epoch = len(data_loader.dataset) // cfg.SOLVER.IMS_PER_BATCH
    self.scheduler = self.build_lr_scheduler(cfg, optimizer, self.iters_per_epoch)

    # Assume no other objects need to be checkpointed.
    # We can later make it checkpoint the stateful hooks
    self.checkpointer = Checkpointer(
        # Assume you want to save checkpoints together with logs/statistics
        model,
        cfg.OUTPUT_DIR,
        save_to_disk=comm.is_main_process(),
        **optimizer_ckpt,
        # NOTE(review): unpacking assumes build_lr_scheduler returns a dict of
        # checkpointable objects — confirm its return type.
        **self.scheduler,
    )

    self.start_epoch = 0
    self.max_epoch = cfg.SOLVER.MAX_EPOCH
    self.max_iter = self.max_epoch * self.iters_per_epoch
    self.warmup_iters = cfg.SOLVER.WARMUP_ITERS
    self.delay_epochs = cfg.SOLVER.DELAY_EPOCHS
    self.cfg = cfg

    self.register_hooks(self.build_hooks())
def __call__(self, embedding, targets): """Forward pass for all input predictions: preds - (batch_size x feat_dims) """ # ------ differentiable ranking of all retrieval set ------ embedding = F.normalize(embedding, dim=1) feat_dim = embedding.size(1) # For distributed training, gather all features from different process. if comm.get_world_size() > 1: all_embedding = concat_all_gather(embedding) all_targets = concat_all_gather(targets) else: all_embedding = embedding all_targets = targets sim_dist = torch.matmul(embedding, all_embedding.t()) N, M = sim_dist.size() # Compute the mask which ignores the relevance score of the query to itself mask_indx = 1.0 - torch.eye(M, device=sim_dist.device) mask_indx = mask_indx.unsqueeze(dim=0).repeat(N, 1, 1) # (N, M, M) # sim_dist -> N, 1, M -> N, M, N sim_dist_repeat = sim_dist.unsqueeze(dim=1).repeat(1, M, 1) # (N, M, M) # sim_dist_repeat_t = sim_dist.t().unsqueeze(dim=1).repeat(1, N, 1) # (N, N, M) # Compute the difference matrix sim_diff = sim_dist_repeat - sim_dist_repeat.permute(0, 2, 1) # (N, M, M) # Pass through the sigmoid sim_sg = sigmoid(sim_diff, temp=self.anneal) * mask_indx # Compute all the rankings sim_all_rk = torch.sum(sim_sg, dim=-1) + 1 # (N, N) pos_mask = targets.view(N, 1).expand(N, M).eq( all_targets.view(M, 1).expand(M, N).t()).float() # (N, M) pos_mask_repeat = pos_mask.unsqueeze(1).repeat(1, M, 1) # (N, M, M) # Compute positive rankings pos_sim_sg = sim_sg * pos_mask_repeat sim_pos_rk = torch.sum(pos_sim_sg, dim=-1) + 1 # (N, N) # sum the values of the Smooth-AP for all instances in the mini-batch ap = 0 group = N // self.num_id for ind in range(self.num_id): pos_divide = torch.sum( sim_pos_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)] / (sim_all_rk[(ind * group):((ind + 1) * group), (ind * group):((ind + 1) * group)])) ap += pos_divide / torch.sum(pos_mask[ind * group]) / N return 1 - ap
def __init__(self, data_source: str, batch_size: int, num_instances: int, delete_rem: bool, seed: Optional[int] = None, cfg=None):
    """Identity sampler that also pre-computes the usable epoch length.

    Args:
        data_source: iterable of (img_path, pid, camid, ...) items.
        batch_size: global batch size.
        num_instances: images sampled per identity per batch.
        delete_rem: if True, drop each identity's remainder images instead of
            padding up to a multiple of num_instances (identities with fewer
            than num_instances images are still padded up — see below).
        seed: shared RNG seed; derived from comm when None.
        cfg: unused here; kept for interface compatibility.
    """
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances
    self.delete_rem = delete_rem

    self.index_pid = defaultdict(list)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # print a short histogram of images-per-identity for dataset inspection
    val_pid_index = [len(x) for x in self.pid_index.values()]
    min_v = min(val_pid_index)
    max_v = max(val_pid_index)
    hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v + 1)]
    num_print = 5
    for i, x in enumerate(range(min_v, min_v + min(len(hist_pid_index), num_print))):
        print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
    print('...')
    print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

    # round each identity's image count to a multiple of num_instances:
    # pad up by default, or drop the remainder when delete_rem is set
    val_pid_index_upper = []
    for x in val_pid_index:
        v_remain = x % self.num_instances
        if v_remain == 0:
            val_pid_index_upper.append(x)
        else:
            if self.delete_rem:
                if x < self.num_instances:  # too few images to drop: pad up instead
                    val_pid_index_upper.append(x - v_remain + self.num_instances)
                else:
                    val_pid_index_upper.append(x - v_remain)
            else:
                val_pid_index_upper.append(x - v_remain + self.num_instances)

    total_images = sum(val_pid_index_upper)
    # approx: truncate to whole batches, then drop one extra batch as margin
    total_images = total_images - (total_images % self.batch_size) - self.batch_size
    self.total_images = total_images
def __init__(self, embedding_size, num_classes, sample_rate, cls_type, scale, margin):
    """Partial-FC style classification head: each rank holds a shard of the
    class-weight matrix and (optionally) samples a subset of it per step.

    Args:
        embedding_size: feature dimension.
        num_classes: total number of classes across all ranks.
        sample_rate: fraction of this rank's classes sampled each step.
        cls_type: name of a margin-softmax class looked up in `any_softmax`.
        scale, margin: forwarded to the softmax head.
    """
    super().__init__()
    self.embedding_size = embedding_size
    self.num_classes = num_classes
    self.sample_rate = sample_rate

    self.world_size = comm.get_world_size()
    self.rank = comm.get_rank()
    self.local_rank = comm.get_local_rank()
    self.device = torch.device(f'cuda:{self.local_rank}')

    # this rank's shard: the first (num_classes % world_size) ranks take one extra class
    self.num_local: int = self.num_classes // self.world_size + int(self.rank < self.num_classes % self.world_size)
    self.class_start: int = self.num_classes // self.world_size * self.rank + \
        min(self.rank, self.num_classes % self.world_size)
    self.num_sample: int = int(self.sample_rate * self.num_local)

    self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)
    """ TODO: consider resume training
    if resume:
        try:
            self.weight: torch.Tensor = torch.load(self.weight_name)
            logging.info("softmax weight resume successfully!")
        except (FileNotFoundError, KeyError, IndexError):
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            logging.info("softmax weight resume fail!")
        try:
            self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
            logging.info("softmax weight mom resume successfully!")
        except (FileNotFoundError, KeyError, IndexError):
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight mom resume fail!")
    else:
    """
    self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
    self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
    logger.info("softmax weight init successfully!")
    logger.info("softmax weight mom init successfully!")

    self.stream: torch.cuda.Stream = torch.cuda.Stream(self.local_rank)
    self.index = None
    if int(self.sample_rate) == 1:
        # no sampling: the whole local shard is trainable
        self.update = lambda: 0
        self.sub_weight = nn.Parameter(self.weight)
        self.sub_weight_mom = self.weight_mom
    else:
        # NOTE(review): this branch appears truncated in the visible source —
        # `sub_weight_mom` is never assigned when sampling is enabled; confirm
        # against the original file.
        self.sub_weight = nn.Parameter(torch.empty((0, 0), device=self.device))
def build_reid_train_loader(cfg, mapper=None, **kwargs):
    """Build the reid training dataloader from config.

    Args:
        cfg (CfgNode): config describing datasets, sampler and batch sizes.
        mapper: optional transform applied to each sample; when None the
            transforms are built from cfg.
        **kwargs: forwarded to each dataset constructor.

    Returns:
        torch.utils.data.DataLoader: a dataloader.
    """
    cfg = cfg.clone()

    train_items = list()
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL, **kwargs)
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    if mapper is not None:
        transforms = mapper
    else:
        transforms = build_transforms(cfg, is_train=True)

    train_set = CommDataset(train_items, transforms, relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    if cfg.DATALOADER.PK_SAMPLER:
        if cfg.DATALOADER.NAIVE_WAY:
            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
        else:
            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))

    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
def _train_loader_from_config(cfg, *, train_set=None, transforms=None, sampler=None, **kwargs):
    """Translate a CfgNode into the explicit keyword arguments expected by
    `build_reid_train_loader` (train_set, sampler, batch size, workers).
    """
    if transforms is None:
        transforms = build_transforms(cfg, is_train=True)

    if train_set is None:
        train_items = list()
        for d in cfg.DATASETS.NAMES:
            data = DATASET_REGISTRY.get(d)(root=_root, **kwargs)
            if comm.is_main_process():
                data.show_train()
            train_items.extend(data.train)
        train_set = CommDataset(train_items, transforms, relabel=True)

    if sampler is None:
        sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
        num_instance = cfg.DATALOADER.NUM_INSTANCE
        mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

        logger = logging.getLogger(__name__)
        logger.info("Using training sampler {}".format(sampler_name))
        if sampler_name == "TrainingSampler":
            sampler = samplers.TrainingSampler(len(train_set))
        elif sampler_name == "NaiveIdentitySampler":
            sampler = samplers.NaiveIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
        elif sampler_name == "BalancedIdentitySampler":
            sampler = samplers.BalancedIdentitySampler(train_set.img_items, mini_batch_size, num_instance)
        elif sampler_name == "SetReWeightSampler":
            set_weight = cfg.DATALOADER.SET_WEIGHT
            sampler = samplers.SetReWeightSampler(train_set.img_items, mini_batch_size, num_instance, set_weight)
        elif sampler_name == "ImbalancedDatasetSampler":
            sampler = samplers.ImbalancedDatasetSampler(train_set.img_items)
        else:
            raise ValueError("Unknown training sampler: {}".format(sampler_name))

    return {
        "train_set": train_set,
        "sampler": sampler,
        "total_batch_size": cfg.SOLVER.IMS_PER_BATCH,
        "num_workers": cfg.DATALOADER.NUM_WORKERS,
    }
def __init__(self, size: int):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
    """
    assert size > 0
    self._size = size
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # split indices into contiguous shards, one per worker (ceil division)
    shard = (self._size + self._world_size - 1) // self._world_size
    start = shard * self._rank
    stop = min(start + shard, self._size)
    self._local_indices = range(start, stop)
def build_reid_train_loader(cfg):
    """Build the reid training dataloader, optionally using the
    clothes-changing dataset wrapper (cfg.DATALOADER.IS_CLO_CHANGES).
    """
    cfg = cfg.clone()
    cfg.defrost()

    train_items = list()
    for d in cfg.DATASETS.NAMES:
        dataset = DATASET_REGISTRY.get(d)(root=_root, combineall=cfg.DATASETS.COMBINEALL)
        if comm.is_main_process():
            dataset.show_train()
        train_items.extend(dataset.train)

    # convert epoch-based MAX_ITER into iterations (mutates the cloned cfg)
    iters_per_epoch = len(train_items) // cfg.SOLVER.IMS_PER_BATCH
    cfg.SOLVER.MAX_ITER *= iters_per_epoch
    train_transforms = build_transforms(cfg, is_train=True)
    if not cfg.DATALOADER.IS_CLO_CHANGES:
        train_set = CommDataset(train_items, train_transforms, relabel=True)
    else:
        # For clothes changes datasets
        train_set = CCDatasets(train_items, train_transforms, relabel=True)

    num_workers = cfg.DATALOADER.NUM_WORKERS
    num_instance = cfg.DATALOADER.NUM_INSTANCE
    mini_batch_size = cfg.SOLVER.IMS_PER_BATCH // comm.get_world_size()

    if cfg.DATALOADER.PK_SAMPLER:
        if cfg.DATALOADER.NAIVE_WAY:
            # NOTE(review): the identity samplers receive the GLOBAL batch size
            # while the BatchSampler below uses the per-GPU size — confirm this
            # asymmetry is intended (sibling builders pass the per-GPU size).
            data_sampler = samplers.NaiveIdentitySampler(train_set.img_items, cfg.SOLVER.IMS_PER_BATCH, num_instance, None, True)
        else:
            data_sampler = samplers.BalancedIdentitySampler(train_set.img_items, cfg.SOLVER.IMS_PER_BATCH, num_instance)
    else:
        data_sampler = samplers.TrainingSampler(len(train_set))

    batch_sampler = torch.utils.data.sampler.BatchSampler(data_sampler, mini_batch_size, True)

    train_loader = torch.utils.data.DataLoader(
        train_set,
        num_workers=num_workers,
        batch_sampler=batch_sampler,
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return train_loader
def __init__(self, cfg):
    """Build the iteration-based trainer with center-loss settings forwarded
    to the parent trainer.

    Args:
        cfg (CfgNode):
    """
    logger = logging.getLogger("fastreid")
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for fastreid
        setup_logger()

    # Assume these objects must be constructed in this order.
    data_loader = self.build_train_loader(cfg)
    cfg = self.auto_scale_hyperparams(cfg, data_loader)
    model = self.build_model(cfg)
    optimizer = self.build_optimizer(cfg, model)

    # For training, wrap with DDP. But don't need this for inference.
    if comm.get_world_size() > 1:
        # ref to https://github.com/pytorch/pytorch/issues/22049 to set `find_unused_parameters=True`
        # for part of the parameters is not updated.
        model = DistributedDataParallel(model, device_ids=[comm.get_local_rank()], broadcast_buffers=False)

    # parent trainer owns the center-loss learning rate/scale and AMP flag
    super().__init__(model, data_loader, optimizer, cfg.SOLVER.BASE_LR, cfg.MODEL.LOSSES.CENTER.LR,
                     cfg.MODEL.LOSSES.CENTER.SCALE, cfg.SOLVER.AMP_ENABLED)

    self.scheduler = self.build_lr_scheduler(cfg, optimizer)
    # Assume no other objects need to be checkpointed.
    # We can later make it checkpoint the stateful hooks
    self.checkpointer = Checkpointer(
        # Assume you want to save checkpoints together with logs/statistics
        model,
        cfg.OUTPUT_DIR,
        save_to_disk=comm.is_main_process(),
        optimizer=optimizer,
        scheduler=self.scheduler,
    )
    self.start_iter = 0
    if cfg.SOLVER.SWA.ENABLED:
        # extend training for the stochastic-weight-averaging iterations
        self.max_iter = cfg.SOLVER.MAX_ITER + cfg.SOLVER.SWA.ITER
    else:
        self.max_iter = cfg.SOLVER.MAX_ITER
    self.cfg = cfg

    self.register_hooks(self.build_hooks())
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the loggers (fvcore logger and the project logger)
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, PathManager.open(args.config_file, "r").read()
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng()

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
def evaluate(self):
    """Compute Recall@K over the gathered features.

    When `_num_query` equals the total feature count there is no separate
    gallery, so the query set is matched against itself.
    """
    if comm.get_world_size() > 1:
        comm.synchronize()
        features = comm.gather(self.features)
        features = sum(features, [])  # flatten per-rank lists
        labels = comm.gather(self.labels)
        labels = sum(labels, [])
        # fmt: off
        if not comm.is_main_process(): return {}
        # fmt: on
    else:
        features = self.features
        labels = self.labels

    features = torch.cat(features, dim=0)
    # query feature, person ids and camera ids
    query_features = features[:self._num_query]
    query_labels = np.asarray(labels[:self._num_query])
    # gallery features, person ids and camera ids
    gallery_features = features[self._num_query:]
    gallery_pids = np.asarray(labels[self._num_query:])

    self._results = OrderedDict()
    if self._num_query == len(features):
        cmc = recall_at_ks(query_features, query_labels, self.recalls, cosine=True)
    else:
        cmc = recall_at_ks(query_features, query_labels, self.recalls, gallery_features, gallery_pids, cosine=True)
    for r in self.recalls:
        self._results['Recall@{}'.format(r)] = cmc[r]
    self._results["metric"] = cmc[self.recalls[0]]
    return copy.deepcopy(self._results)
def __call__(self, embedding, targets):
    """Pairwise circle loss on cosine similarity, masking each sample's
    similarity to itself out of the positive set.

    Args:
        embedding: (N, D) features, L2-normalized below.
        targets: (N,) integer identity labels.

    Returns:
        Scalar loss, scaled by self._scale.
    """
    embedding = nn.functional.normalize(embedding, dim=1)

    if comm.get_world_size() > 1:
        all_embedding = concat_all_gather(embedding)
        all_targets = concat_all_gather(targets)
    else:
        all_embedding = embedding
        all_targets = targets

    dist_mat = torch.matmul(embedding, all_embedding.t())

    N, M = dist_mat.size()
    is_pos = targets.view(N, 1).expand(N, M).eq(all_targets.view(M, 1).expand(M, N).t()).float()

    # Compute the mask which ignores the relevance score of the query to itself
    if M > N:
        # gathered case: NOTE(review) — this assumes the local batch occupies
        # the first N global columns; confirm concat_all_gather's ordering.
        identity_indx = torch.eye(N, N, device=is_pos.device)
        remain_indx = torch.zeros(N, M - N, device=is_pos.device)
        identity_indx = torch.cat((identity_indx, remain_indx), dim=1)
        is_pos = is_pos - identity_indx
    else:
        is_pos = is_pos - torch.eye(N, N, device=is_pos.device)

    is_neg = targets.view(N, 1).expand(N, M).ne(all_targets.view(M, 1).expand(M, N).t())

    # masked similarities (zero where the mask is off)
    s_p = dist_mat * is_pos
    s_n = dist_mat * is_neg

    # detached adaptive weights and shifted margins (circle-loss form)
    alpha_p = torch.clamp_min(-s_p.detach() + 1 + self._m, min=0.)
    alpha_n = torch.clamp_min(s_n.detach() + self._m, min=0.)
    delta_p = 1 - self._m
    delta_n = self._m

    logit_p = -self._s * alpha_p * (s_p - delta_p)
    logit_n = self._s * alpha_n * (s_n - delta_n)

    loss = nn.functional.softplus(torch.logsumexp(logit_p, dim=1) + torch.logsumexp(logit_n, dim=1)).mean()

    return loss * self._scale
def build_face_test_loader(cfg, dataset_name, **kwargs):
    """Build a face-verification test dataloader.

    Returns:
        (DataLoader, labels): the loader and the dataset's pair labels.
    """
    dataset = DATASET_REGISTRY.get(dataset_name)(root=_root, **kwargs)
    if comm.is_main_process():
        dataset.show_test()

    test_set = FaceCommDataset(dataset.carray, dataset.is_same)

    batch_per_gpu = cfg.TEST.IMS_PER_BATCH // comm.get_world_size()
    inf_sampler = samplers.InferenceSampler(len(test_set))
    batch_sampler = torch.utils.data.BatchSampler(inf_sampler, batch_per_gpu, False)
    loader = DataLoader(
        test_set,
        batch_sampler=batch_sampler,
        num_workers=4,  # save some memory
        collate_fn=fast_batch_collator,
        pin_memory=True,
    )
    return loader, test_set.labels