def __init__(self, data_source, mini_batch_size: int, num_instances: int = 4,
             seed: Optional[int] = None, with_mem_idx=False):
    self.data_source = data_source
    self.num_instances = num_instances
    self.num_pids_per_batch = mini_batch_size // self.num_instances
    self.with_mem_idx = with_mem_idx

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
    self.batch_size = mini_batch_size * self._world_size

    # Group the dataset by identity and record the cameras each identity appears in.
    self.index_pid = defaultdict(int)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = sorted(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
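# Hedged usage sketch for the sampler __init__ above (illustration only, not
# part of the original file; the helper name _pk_batch_shape is hypothetical):
# with mini_batch_size=64, num_instances=4 and two ranks, each rank draws
# 64 // 4 = 16 identities per mini-batch, and the global batch spans
# mini_batch_size * world_size images.
def _pk_batch_shape(mini_batch_size: int, num_instances: int, world_size: int):
    num_pids_per_batch = mini_batch_size // num_instances
    batch_size = mini_batch_size * world_size
    return num_pids_per_batch, batch_size

assert _pk_batch_shape(64, 4, 2) == (16, 128)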
def __init__(self, data_source: List, mini_batch_size: int, num_instances: int,
             set_weight: list, seed: Optional[int] = None):
    self.data_source = data_source
    self.num_instances = num_instances
    self.num_pids_per_batch = mini_batch_size // self.num_instances
    self.set_weight = set_weight

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
    self.batch_size = mini_batch_size * self._world_size

    assert self.batch_size % (sum(self.set_weight) * self.num_instances) == 0 and \
        self.batch_size > sum(self.set_weight) * self.num_instances, \
        "Batch size must be divisible by and larger than sum(set_weight) * num_instances"

    self.index_pid = dict()
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)
    self.cam_pid = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)
        self.cam_pid[camid].append(pid)

    # Get the sampling probability for each camera: under a camera, each
    # identity is weighted by its number of images.
    self.set_pid_prob = defaultdict(list)
    for camid, pid_list in self.cam_pid.items():
        index_per_pid = [len(self.pid_index[pid]) for pid in pid_list]
        cam_image_number = sum(index_per_pid)
        self.set_pid_prob[camid] = [i / cam_image_number for i in index_per_pid]

    self.pids = sorted(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
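# Worked example (standalone sketch, not part of the original file) of the
# per-camera probabilities computed above. The toy dicts mirror pid_index and
# cam_pid from __init__: identity 0 owns 3 of camera 0's 4 images, so it is
# sampled with probability 0.75 under that camera.
toy_pid_index = {0: [0, 1, 2], 1: [3]}   # pid -> image indices
toy_cam_pid = {0: [0, 1]}                # camid -> pids seen by that camera
for camid, pid_list in toy_cam_pid.items():
    index_per_pid = [len(toy_pid_index[pid]) for pid in pid_list]
    cam_image_number = sum(index_per_pid)
    prob = [i / cam_image_number for i in index_per_pid]
    assert prob == [0.75, 0.25]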
def __init__(self, data_source: List, batch_size: int, num_instances: int,
             seed: Optional[int] = None):
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances

    self.index_pid = defaultdict(list)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        # Use the domain label in place of the camera id.
        # camid = info[2]
        camid = info[3]['domains']
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
def step(self, epoch: int, **kwargs: Any):
    rank = comm.get_rank()
    # Save periodic snapshots, and always save once at the final epoch.
    if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1:
        self.checkpointer.save(f"softmax_weight_{epoch:04d}_rank_{rank:02d}")
    if epoch >= self.max_epoch - 1:
        self.checkpointer.save(f"softmax_weight_{rank:02d}")
def __init__(self, data_source: List, batch_size: int, num_instances: int,
             delete_rem: bool, seed: Optional[int] = None, cfg=None):
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances
    self.delete_rem = delete_rem

    self.index_pid = defaultdict(list)
    self.pid_cam = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        pid = info[1]
        camid = info[2]
        self.index_pid[index] = pid
        self.pid_cam[pid].append(camid)
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.num_identities = len(self.pids)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Print a short histogram of images-per-identity (first few bins plus the last).
    val_pid_index = [len(x) for x in self.pid_index.values()]
    min_v = min(val_pid_index)
    max_v = max(val_pid_index)
    hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v + 1)]
    num_print = 5
    for i, x in enumerate(range(min_v, min_v + min(len(hist_pid_index), num_print))):
        print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
    print('...')
    print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

    # Round each identity's image count to a multiple of num_instances:
    # pad up by default, or trim down when delete_rem is set (identities with
    # fewer than num_instances images are still padded up).
    val_pid_index_upper = []
    for x in val_pid_index:
        v_remain = x % self.num_instances
        if v_remain == 0:
            val_pid_index_upper.append(x)
        elif self.delete_rem and x >= self.num_instances:
            val_pid_index_upper.append(x - v_remain)
        else:
            val_pid_index_upper.append(x - v_remain + self.num_instances)

    total_images = sum(val_pid_index_upper)
    # Approximate epoch length, truncated to a multiple of the batch size.
    total_images = total_images - (total_images % self.batch_size) - self.batch_size
    self.total_images = total_images
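# Worked example (standalone sketch, not part of the original file) of the
# per-identity rounding above; the helper name _round_count is hypothetical.
# Counts are padded up to a multiple of num_instances, or trimmed down when
# delete_rem is set, except that identities smaller than num_instances always
# pad up.
def _round_count(x: int, num_instances: int, delete_rem: bool) -> int:
    v_remain = x % num_instances
    if v_remain == 0:
        return x
    if delete_rem and x >= num_instances:
        return x - v_remain
    return x - v_remain + num_instances

assert _round_count(10, 4, delete_rem=False) == 12  # padded up
assert _round_count(10, 4, delete_rem=True) == 8    # trimmed down
assert _round_count(3, 4, delete_rem=True) == 4     # small identity pads up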
def __init__(self, model, save_dir, *, save_to_disk=True, **checkpointables):
    super().__init__(model, save_dir, save_to_disk=save_to_disk, **checkpointables)
    self.rank = comm.get_rank()
def __init__(self, embedding_size, num_classes, sample_rate, cls_type, scale, margin):
    super().__init__()
    self.embedding_size = embedding_size
    self.num_classes = num_classes
    self.sample_rate = sample_rate

    self.world_size = comm.get_world_size()
    self.rank = comm.get_rank()
    self.local_rank = comm.get_local_rank()
    self.device = torch.device(f'cuda:{self.local_rank}')

    # Shard the classes across ranks; the first (num_classes % world_size)
    # ranks take one extra class.
    self.num_local: int = self.num_classes // self.world_size + int(
        self.rank < self.num_classes % self.world_size)
    self.class_start: int = self.num_classes // self.world_size * self.rank + \
        min(self.rank, self.num_classes % self.world_size)
    self.num_sample: int = int(self.sample_rate * self.num_local)

    self.cls_layer = getattr(any_softmax, cls_type)(num_classes, scale, margin)

    """ TODO: consider resume training
    if resume:
        try:
            self.weight: torch.Tensor = torch.load(self.weight_name)
            logging.info("softmax weight resume successfully!")
        except (FileNotFoundError, KeyError, IndexError):
            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
            logging.info("softmax weight resume fail!")

        try:
            self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
            logging.info("softmax weight mom resume successfully!")
        except (FileNotFoundError, KeyError, IndexError):
            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
            logging.info("softmax weight mom resume fail!")
    else:
    """
    self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
    self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
    logger.info("softmax weight init successfully!")
    logger.info("softmax weight mom init successfully!")

    self.stream: torch.cuda.Stream = torch.cuda.Stream(self.local_rank)

    self.index = None
    if int(self.sample_rate) == 1:
        # No sub-sampling: the full local shard is the trainable parameter.
        self.update = lambda: 0
        self.sub_weight = nn.Parameter(self.weight)
        self.sub_weight_mom = self.weight_mom
    else:
        self.sub_weight = nn.Parameter(torch.empty((0, 0), device=self.device))
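# Worked example (standalone sketch, not part of the original file) of the
# class-shard arithmetic above; the helper name _class_shard is hypothetical.
# 10 classes over 3 ranks give local counts [4, 3, 3] starting at [0, 4, 7]:
# contiguous shards that together cover all classes exactly once.
def _class_shard(num_classes: int, world_size: int, rank: int):
    num_local = num_classes // world_size + int(rank < num_classes % world_size)
    class_start = num_classes // world_size * rank + min(rank, num_classes % world_size)
    return num_local, class_start

assert [_class_shard(10, 3, r) for r in range(3)] == [(4, 0), (3, 4), (3, 7)]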
def __init__(self, size: int):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
    """
    self._size = size
    assert size > 0
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Split the indices into contiguous shards of (ceil) equal size, one per rank.
    shard_size = (self._size - 1) // self._world_size + 1
    begin = shard_size * self._rank
    end = min(shard_size * (self._rank + 1), self._size)
    self._local_indices = range(begin, end)
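# Worked example (standalone sketch, not part of the original file) of the
# shard arithmetic above; the helper name _shard is hypothetical. 10 samples
# over 3 ranks give shard_size = ceil(10 / 3) = 4, with the last shard
# truncated to the dataset size.
def _shard(size: int, world_size: int, rank: int) -> range:
    shard_size = (size - 1) // world_size + 1
    begin = shard_size * rank
    end = min(shard_size * (rank + 1), size)
    return range(begin, end)

assert list(_shard(10, 3, 0)) == [0, 1, 2, 3]
assert list(_shard(10, 3, 1)) == [4, 5, 6, 7]
assert list(_shard(10, 3, 2)) == [8, 9]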
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the detectron2 logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    if hasattr(args, "config_file") and args.config_file != "":
        logger.info(
            "Contents of args.config_file={}:\n{}".format(
                args.config_file, PathManager.open(args.config_file, "r").read()
            )
        )

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng()

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same
            across all workers. If None, will use a random seed shared
            among workers (require synchronization among all workers).
    """
    self._size = size
    assert size > 0
    self._shuffle = shuffle
    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
def __call__(self, feats, labels):
    feats = F.normalize(feats, dim=1)

    # Gather features and labels from all ranks so negatives come from the
    # whole global batch.
    if comm.get_world_size() > 1:
        all_feats = concat_all_gather(feats)
        all_labels = concat_all_gather(labels)
    else:
        all_feats = feats
        all_labels = labels

    batch_size = feats.size(0)
    sim_mat = torch.mm(feats, all_feats.t())

    losses = []
    rank = comm.get_rank()
    for i in range(batch_size):
        pos_idxs = (all_labels == labels[i])
        # Exclude the anchor itself from its positive set (assumes every rank
        # contributes the same batch_size).
        pos_idxs[rank * batch_size + i] = False
        pos_pair_ = sim_mat[i][pos_idxs]
        neg_pair_ = sim_mat[i][all_labels != labels[i]]

        if self.hard_mining:
            # Keep only pairs that violate the margin relative to the
            # hardest positive/negative.
            neg_pair = neg_pair_[neg_pair_ + self.margin > min(pos_pair_)]
            pos_pair = pos_pair_[pos_pair_ - self.margin < max(neg_pair_)]
        else:
            neg_pair = neg_pair_
            pos_pair = pos_pair_

        if len(pos_pair) < 1 or len(neg_pair) < 1:
            continue

        # Randomly pick one positive pair for this anchor.
        if len(pos_pair) > 1:
            pos_pair = pos_pair[random.randint(0, len(pos_pair) - 1)]
        else:
            pos_pair = pos_pair[0]

        loss = torch.log(1 + torch.sum(torch.exp(self.scale * (neg_pair - pos_pair))))
        losses.append(loss)

    # Average over anchors that produced a valid pair (assumes at least one did).
    return sum(losses) / len(losses)
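# Standalone numeric sketch (not part of the original file) of the per-anchor
# loss term above: log(1 + sum(exp(scale * (s_neg - s_pos)))) stays near zero
# when every negative similarity sits well below the chosen positive's, and
# grows as negatives approach or exceed it.
import torch

scale = 16.0
pos_pair = torch.tensor(0.8)            # similarity to the sampled positive
neg_pair = torch.tensor([0.3, 0.5])     # similarities to the negatives
loss = torch.log(1 + torch.sum(torch.exp(scale * (neg_pair - pos_pair))))
assert loss.item() < 0.01               # hardest negative is still 0.3 below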
def __init__(self, data_source: List, size: int = None, seed: Optional[int] = None,
             callback_get_label: Callable = None):
    self.data_source = data_source
    # consider all elements in the dataset
    self.indices = list(range(len(data_source)))
    # if size is not provided, draw `len(indices)` samples in each iteration
    self._size = len(self.indices) if size is None else size
    self.callback_get_label = callback_get_label

    # distribution of classes in the dataset
    label_to_count = {}
    for idx in self.indices:
        label = self._get_label(data_source, idx)
        label_to_count[label] = label_to_count.get(label, 0) + 1

    # weight each sample by the inverse frequency of its class
    weights = [1.0 / label_to_count[self._get_label(data_source, idx)]
               for idx in self.indices]
    self.weights = torch.DoubleTensor(weights)

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
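# Worked example (standalone sketch, not part of the original file) of the
# inverse-frequency weighting above: a 3:1 imbalanced label set gives the
# minority class three times the per-sample weight, so both classes are drawn
# with equal total probability under weighted sampling with replacement.
import torch

toy_labels = [0, 0, 0, 1]
label_to_count = {}
for label in toy_labels:
    label_to_count[label] = label_to_count.get(label, 0) + 1
weights = torch.DoubleTensor([1.0 / label_to_count[label] for label in toy_labels])
assert weights.tolist() == [1 / 3, 1 / 3, 1 / 3, 1.0]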
def __init__(self, data_source: List, batch_size: int, num_instances: int,
             delete_rem: bool, seed: Optional[int] = None, cfg=None):
    self.data_source = data_source
    self.batch_size = batch_size
    self.num_instances = num_instances
    self.num_pids_per_batch = batch_size // self.num_instances
    self.delete_rem = delete_rem

    self.index_pid = defaultdict(list)
    self.pid_domain = defaultdict(list)
    self.pid_index = defaultdict(list)

    for index, info in enumerate(data_source):
        domainid = info[3]['domains']
        if cfg.DATALOADER.CAMERA_TO_DOMAIN:
            # Treat each camera as a domain: make pids unique per domain.
            pid = info[1] + str(domainid)
        else:
            pid = info[1]
        self.index_pid[index] = pid
        self.pid_domain[pid] = domainid
        self.pid_index[pid].append(index)

    self.pids = list(self.pid_index.keys())
    self.domains = list(self.pid_domain.values())
    self.num_identities = len(self.pids)
    self.num_domains = len(set(self.domains))

    # Split the batch evenly across domains.
    self.batch_size //= self.num_domains
    self.num_pids_per_batch //= self.num_domains

    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Print a short histogram of images-per-identity (first few bins plus the last).
    val_pid_index = [len(x) for x in self.pid_index.values()]
    min_v = min(val_pid_index)
    max_v = max(val_pid_index)
    hist_pid_index = [val_pid_index.count(x) for x in range(min_v, max_v + 1)]
    num_print = 5
    for i, x in enumerate(range(min_v, min_v + min(len(hist_pid_index), num_print))):
        print('dataset histogram [bin:{}, cnt:{}]'.format(x, hist_pid_index[i]))
    print('...')
    print('dataset histogram [bin:{}, cnt:{}]'.format(max_v, val_pid_index.count(max_v)))

    # Round each identity's image count to a multiple of num_instances
    # (pad up by default, trim down when delete_rem is set; identities with
    # fewer than num_instances images are still padded up).
    val_pid_index_upper = []
    for x in val_pid_index:
        v_remain = x % self.num_instances
        if v_remain == 0:
            val_pid_index_upper.append(x)
        elif self.delete_rem and x >= self.num_instances:
            val_pid_index_upper.append(x - v_remain)
        else:
            val_pid_index_upper.append(x - v_remain + self.num_instances)

    # Epoch length: every domain is stretched to the largest domain's count,
    # truncated to a multiple of the (per-domain) batch size.
    cnt_domains = [0 for _ in range(self.num_domains)]
    for val, index in zip(val_pid_index_upper, self.domains):
        cnt_domains[index] += val
    self.max_cnt_domains = max(cnt_domains)
    self.total_images = self.num_domains * (
        self.max_cnt_domains - (self.max_cnt_domains % self.batch_size) - self.batch_size)
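# Worked example (standalone sketch, not part of the original file) of the
# epoch-length computation above: every domain is stretched to the largest
# domain's rounded image count, then truncated to a multiple of the
# per-domain batch size.
num_domains, per_domain_batch = 2, 32
cnt_domains = [100, 260]
max_cnt = max(cnt_domains)
total_images = num_domains * (max_cnt - (max_cnt % per_domain_batch) - per_domain_batch)
assert total_images == 448              # 2 * (260 - 4 - 32)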