def train(cfg, args, model, device, distributed):
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = comm.get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    # Compare strings with ==, not identity (`is`), which is unreliable for str literals.
    if cfg.MODEL.BACKBONE.CONV_BODY == "DLA-34-DCN":
        ckpt = cfg.MODEL.WEIGHT if args.ckpt is None else args.ckpt
        extra_checkpoint_data = checkpointer.load(ckpt)
        arguments.update(extra_checkpoint_data)
    elif args.ckpt is not None:
        extra_checkpoint_data = checkpointer.load(args.ckpt)
        arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    do_train(
        cfg, distributed, model, data_loader, optimizer, scheduler,
        checkpointer, device, checkpoint_period, arguments
    )
def train(cfg, model, device, distributed):
    optimizer = make_optimizer(cfg, model)
    scheduler = make_lr_scheduler(cfg, optimizer)

    arguments = {}
    arguments["iteration"] = 0

    output_dir = cfg.OUTPUT_DIR
    save_to_disk = comm.get_rank() == 0
    checkpointer = DetectronCheckpointer(
        cfg, model, optimizer, scheduler, output_dir, save_to_disk
    )
    extra_checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT)
    arguments.update(extra_checkpoint_data)

    data_loader = make_data_loader(
        cfg,
        is_train=True,
    )

    checkpoint_period = cfg.SOLVER.CHECKPOINT_PERIOD
    do_train(
        cfg, distributed, model, data_loader, optimizer, scheduler,
        checkpointer, device, checkpoint_period, arguments
    )
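# A minimal sketch of how train() above might be driven from a script entry point.
# setup(), build_detection_model(), and args.local_rank are assumptions here
# (placeholders for whatever the surrounding repo actually provides), not part of
# the snippets above.
import torch

def main(args):
    cfg = setup(args)  # assumed helper that merges args.config_file into the config

    model = build_detection_model(cfg)  # assumed model factory
    device = torch.device(cfg.MODEL.DEVICE)
    model.to(device)

    distributed = comm.get_world_size() > 1
    if distributed:
        # Wrap the model for multi-GPU training; args.local_rank is assumed to be
        # provided by the launcher (e.g. torch.distributed.launch / torchrun).
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], broadcast_buffers=False
        )

    train(cfg, model, device, distributed)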
def default_setup(cfg, args):
    output_dir = cfg.OUTPUT_DIR
    if output_dir:
        mkdir(output_dir)

    rank = comm.get_rank()
    logger = setup_logger(output_dir, rank)
    logger.info("Using {} GPUs".format(args.num_gpus))
    logger.info("Collecting environment info")
    logger.info("\n" + collect_env_info())

    logger.info(args)
    logger.info("Loaded configuration file {}".format(args.config_file))
    with open(args.config_file, "r") as cf:
        config_str = "\n" + cf.read()
        logger.info(config_str)
    logger.info("Running with config:\n{}".format(cfg))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK
def __init__(self, size: int):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
    """
    self._size = size
    assert size > 0
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    shard_size = (self._size - 1) // self._world_size + 1
    begin = shard_size * self._rank
    end = min(shard_size * (self._rank + 1), self._size)
    self._local_indices = range(begin, end)
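# Worked example of the shard arithmetic above (a standalone sketch, not part of
# the sampler): with size=10 and world_size=4, shard_size = (10 - 1) // 4 + 1 = 3,
# so ranks 0..3 receive range(0, 3), range(3, 6), range(6, 9), range(9, 10).
def _shard(size, rank, world_size):
    shard_size = (size - 1) // world_size + 1
    begin = shard_size * rank
    end = min(shard_size * (rank + 1), size)
    return range(begin, end)

assert [list(_shard(10, r, 4)) for r in range(4)] == [
    [0, 1, 2], [3, 4, 5], [6, 7, 8], [9]
]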
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same across all workers.
            If None, will use a random seed shared among workers
            (require synchronization among all workers).
    """
    self._size = size
    assert size > 0
    self._shuffle = shuffle
    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()
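# A minimal sketch (an assumption about the rest of this sampler, not copied from
# the snippet above) of how a sampler initialized this way typically yields indices:
# an infinite index stream seeded with self._seed, sharded across workers by
# striding with rank / world_size.
import itertools
import torch

def __iter__(self):
    yield from itertools.islice(
        self._infinite_indices(), self._rank, None, self._world_size
    )

def _infinite_indices(self):
    g = torch.Generator()
    g.manual_seed(self._seed)  # same seed on every worker -> identical stream before sharding
    while True:
        if self._shuffle:
            yield from torch.randperm(self._size, generator=g).tolist()
        else:
            yield from torch.arange(self._size).tolist()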
def __init__(self, dataset_dicts, repeat_thresh, shuffle=True, seed=None):
    """
    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 dataset format.
        repeat_thresh (float): frequency threshold below which data is repeated.
        shuffle (bool): whether to shuffle the indices or not
        seed (int): the initial seed of the shuffle. Must be the same across all workers.
            If None, will use a random seed shared among workers
            (require synchronization among all workers).
    """
    self._shuffle = shuffle
    if seed is None:
        seed = comm.shared_random_seed()
    self._seed = int(seed)

    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Get fractional repeat factors and split into whole number (_int_part)
    # and fractional (_frac_part) parts.
    rep_factors = self._get_repeat_factors(dataset_dicts, repeat_thresh)
    self._int_part = torch.trunc(rep_factors)
    self._frac_part = rep_factors - self._int_part
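# A hedged sketch of how the _int_part / _frac_part split prepared above is
# typically consumed. The stochastic rounding shown here is an assumption about
# the rest of this sampler (in the LVIS-style scheme the per-image repeat factor
# is max(1, sqrt(repeat_thresh / category_frequency))), included only to make the
# two tensors concrete.
def _get_epoch_indices(self, generator):
    # Stochastic rounding: an image with repeat factor 2.3 is repeated 2 times in
    # roughly 70% of epochs and 3 times in the remaining 30%, so the expected
    # number of repeats matches the fractional factor.
    rands = torch.rand(len(self._frac_part), generator=generator)
    rep_factors = self._int_part + (rands < self._frac_part).float()
    indices = []
    for dataset_index, rep_factor in enumerate(rep_factors):
        indices.extend([dataset_index] * int(rep_factor.item()))
    return torch.tensor(indices, dtype=torch.int64)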