def __init__(self, size: int):
    """
    Record the dataset size and compute the contiguous range of indices
    assigned to the current process.

    Args:
        size (int): the total number of data of the underlying dataset to sample from
    """
    self._size = size
    assert size > 0
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Ceiling division of size by world size: the length of a full shard.
    per_rank = (self._size - 1) // self._world_size + 1
    start = per_rank * self._rank
    # Clamp the end so the final shard never runs past the dataset.
    stop = min(per_rank * (self._rank + 1), self._size)
    self._local_indices = range(start, stop)
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the loggers (one named "fvcore" and the default one)
    2. Log basic information about environment, cmdline arguments, and config
       (environment info and the full config are skipped when
       ``cfg.MUTE_HEADER`` is set)
    3. Backup the config to the output directory (main process only)
    4. Seed the random number generators and configure cudnn benchmark

    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged.
            If it has a non-empty ``config_file`` attribute, the contents of
            that file are logged as well.
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(
        rank, comm.get_world_size()))
    if not cfg.MUTE_HEADER:
        logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))

    # Skip a missing/empty config path instead of trying to open it, and read
    # the file through a context manager so the handle is always closed.
    config_file = getattr(args, "config_file", None)
    if config_file:
        with PathManager.open(config_file, "r") as f:
            config_contents = f.read()
        logger.info("Contents of args.config_file={}:\n{}".format(
            config_file, config_contents))

    if not cfg.MUTE_HEADER:
        logger.info("Running with full config:\n{}".format(cfg))

    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
    """
    Store the sampling configuration together with this process's rank and
    the total number of processes.

    Args:
        size (int): the total number of data of the underlying dataset to sample from
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same
            across all workers. If None, will use a random seed shared
            among workers (require synchronization among all workers).
    """
    self._size = size
    assert size > 0
    self._shuffle = shuffle

    # Fall back to a seed agreed on by all workers when none was supplied.
    self._seed = int(comm.shared_random_seed() if seed is None else seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
def __init__(self, dataset_dicts, repeat_thresh, shuffle=True, seed=None):
    """
    Store the sampling configuration and precompute the per-item repeat
    factors, split into their whole-number and fractional components.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
        repeat_thresh (float): frequency threshold below which data is repeated.
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same
            across all workers. If None, will use a random seed shared
            among workers (require synchronization among all workers).
    """
    self._shuffle = shuffle

    # Fall back to a seed agreed on by all workers when none was supplied.
    self._seed = int(comm.shared_random_seed() if seed is None else seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Fractional repeat factors, decomposed as whole + remainder.
    repeat_factors = self._get_repeat_factors(dataset_dicts, repeat_thresh)
    whole = torch.trunc(repeat_factors)
    self._int_part = whole
    self._frac_part = repeat_factors - whole
def forward(self, input):
    """
    Normalize ``input`` using statistics combined across all workers.

    With a single worker, or when the module is not in training mode, this
    defers to the parent class's ``forward``. Otherwise the per-channel
    mean and mean-of-squares are computed locally over dims 0, 2 and 3,
    combined through ``AllReduce`` and scaled by ``1 / world_size``
    (presumably ``AllReduce`` sums across workers, making this an average --
    confirm against its definition). The running statistics are updated
    in place with ``self.momentum``.

    Args:
        input (Tensor): indexed as a 4-D tensor with channels on dim 1.

    Returns:
        Tensor: the normalized input after the affine transform.
    """
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    assert input.shape[0] > 0, "SyncBatchNorm does not support empty inputs"
    num_channels = input.shape[1]
    reduce_dims = [0, 2, 3]

    local_mean = torch.mean(input, dim=reduce_dims)
    local_meansqr = torch.mean(input * input, dim=reduce_dims)

    # Pack both statistics into one vector so a single all-reduce suffices.
    stats = torch.cat([local_mean, local_meansqr], dim=0)
    stats = AllReduce.apply(stats) * (1.0 / dist.get_world_size())
    mean, meansqr = torch.split(stats, num_channels)

    # Var[x] = E[x^2] - E[x]^2
    var = meansqr - mean * mean

    # Running statistics are updated from detached values (no gradient flow).
    self.running_mean += self.momentum * (mean.detach() - self.running_mean)
    self.running_var += self.momentum * (var.detach() - self.running_var)

    # Fold normalization and the affine parameters into one scale and bias.
    scale = self.weight * torch.rsqrt(var + self.eps)
    bias = self.bias - mean * scale
    return input * scale.reshape(1, -1, 1, 1) + bias.reshape(1, -1, 1, 1)
def __init__(self, cfg):
    """
    Build the model, optimizer, data loader(s), LR scheduler and
    checkpointer from ``cfg``, then register the hooks.

    Args:
        cfg (CfgNode): the config; this constructor reads ``cfg.SSL``,
            ``cfg.OUTPUT_DIR`` and ``cfg.SOLVER.MAX_ITER`` directly and
            passes ``cfg`` on to the ``build_*`` methods.
    """
    # Assume these objects must be constructed in this order.
    model = self.build_model(cfg)
    optimizer = self.build_optimizer(cfg, model)
    data_loader = self.build_train_loader(cfg)
    # The SSL loader is only built when enabled in the config.
    data_loader_ssl = self.build_ssl_loader(cfg) if cfg.SSL else None

    # For training, wrap with DDP. But don't need this for inference.
    if comm.get_world_size() > 1:
        model = DistributedDataParallel(
            model,
            device_ids=[comm.get_local_rank()],
            broadcast_buffers=False,
            find_unused_parameters=True,
        )
    super().__init__(model, data_loader, optimizer, data_loader_ssl)

    scheduler = self.build_lr_scheduler(cfg, optimizer)
    self.scheduler = scheduler

    # Assume no other objects need to be checkpointed.
    # We can later make it checkpoint the stateful hooks.
    # Checkpoints are saved together with logs/statistics in cfg.OUTPUT_DIR.
    self.checkpointer = DetectionCheckpointer(
        model,
        cfg.OUTPUT_DIR,
        optimizer=optimizer,
        scheduler=self.scheduler,
    )

    self.start_iter = 0
    self.max_iter = cfg.SOLVER.MAX_ITER
    self.cfg = cfg

    hooks = self.build_hooks()
    self.register_hooks(hooks)
def build_ssl_train_loader(cfg, mapper=None):
    """
    Build the training data loader for the datasets listed in
    ``cfg.DATASETS.SSL``.

    The loader is created by the following steps:

    1. Query the dataset dicts for ``cfg.DATASETS.SSL`` through
       :func:`get_detection_dataset_dicts`.
    2. Wrap them so that each item is passed through ``mapper``.
    3. Pick the sampler named by ``cfg.DATALOADER.SAMPLER_TRAIN`` and batch
       its indices into per-worker batches.

    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True, True)`.

    Returns:
        a torch DataLoader object

    Raises:
        ValueError: if ``cfg.DATALOADER.SAMPLER_TRAIN`` names an unknown sampler.
    """
    world_size = get_world_size()
    total_batch_size = cfg.SOLVER.IMS_PER_BATCH
    assert (
        total_batch_size % world_size == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        total_batch_size, world_size)
    assert (
        total_batch_size >= world_size
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
        total_batch_size, world_size)
    per_worker_batch_size = total_batch_size // world_size

    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.SSL,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
    )
    dataset = DatasetFromList(dataset_dicts, copy=False)

    if mapper is None:
        mapper = DatasetMapper(cfg, True, True)
    dataset = MapDataset(dataset, mapper)

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logging.getLogger(__name__).info(
        "Using training sampler {}".format(sampler_name))

    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        sampler = samplers.RepeatFactorTrainingSampler(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD)
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    batch_sampler = build_batch_data_sampler(sampler, per_worker_batch_size)

    return torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )