def __init__(self, cfg):
    """
    Build the model, its optimizers/checkpointers and the training data loader,
    then initialize trainer state from ``cfg`` and register hooks.

    Args:
        cfg (CfgNode): the full config. Besides whatever ``build_model`` /
            ``build_train_loader`` read, this reads ``SOLVER.MAX_ITER``,
            ``GAN_MODE_ON``, ``SOLVER.SUPERVISED_MAX_ITER``,
            ``SOLVER.D_UPDATE_RATIO``, ``SOLVER.D_INIT_ITERS`` and
            ``SOLVER.ACCUMULATION_STEPS``.
    """
    logger = logging.getLogger("vidgen")
    if not logger.isEnabledFor(logging.INFO):  # setup_logger is not called for d2
        setup_logger()

    # Assume these objects must be constructed in this order.
    self.model = self.build_model(cfg)
    self.optimizers, self.checkpointers = self.model.configure_optimizers_and_checkpointers()
    # build_train_loader also returns the dataset length, which is not needed here.
    self.data_loader, _ = self.build_train_loader(cfg)
    self._data_loader_iter = iter(self.data_loader)

    # For training, wrap with DDP. But don't need this for inference.
    if comm.get_world_size() > 1:
        self.model.wrap_parallel(device_ids=[comm.get_local_rank()], broadcast_buffers=False)

    super().__init__(cfg)
    self.model.train()

    self.start_iter = 0
    self.max_iter = cfg.SOLVER.MAX_ITER
    self.gan_mode_on = cfg.GAN_MODE_ON
    self.supervised_max_iter = cfg.SOLVER.SUPERVISED_MAX_ITER
    self.d_update_ratio = cfg.SOLVER.D_UPDATE_RATIO
    self.d_init_iters = cfg.SOLVER.D_INIT_ITERS
    # Assigned before build_hooks() (the original set it afterwards) so that every
    # config-derived attribute already exists when hooks are constructed.
    self.accumulation_steps = cfg.SOLVER.ACCUMULATION_STEPS
    self.cfg = cfg

    self.register_hooks(self.build_hooks())
def _straight_through(self, z_e_x):
    """
    Quantize ``z_e_x`` against the codebook ``self.embedding.weight``. When
    ``self.ema`` is set, first update the codebook with exponential moving averages.

    Args:
        z_e_x: encoder output. The ``permute(0, 2, 3, 1)`` and the
            ``b, h, w, c`` unpacking below imply a 4-D channels-first tensor
            (presumably B x C x H x W).

    Returns:
        tuple: ``(z_q_x, z_q_x_bar)``, both permuted back to channels-first.
        ``z_q_x`` comes from ``vq_st`` applied to a *detached* codebook.
        ``z_q_x_bar`` is gathered from ``self.embedding.weight`` itself, after any
        EMA update (presumably so the codebook can receive gradients through it).
    """
    z_e_x_ = z_e_x.permute(0, 2, 3, 1).contiguous()
    z_q_x_, indices = vq_st(z_e_x_, self.embedding.weight.detach())
    z_q_x = z_q_x_.permute(0, 3, 1, 2).contiguous()

    if self.ema:
        # Use EMA to update the embedding vectors
        with torch.no_grad():
            device = indices.device

            # Number of encoder vectors assigned to each code in this batch
            # (summed over ranks when distributed).
            size = torch.zeros_like(self.running_size, dtype=indices.dtype, device=device)
            size.index_add_(dim=0, index=indices, source=torch.ones_like(indices, device=device))
            if comm.get_world_size() > 1:
                size = AllReduce.apply(size)
            # running = decay * running + (1 - decay) * batch.  Uses the ``alpha=``
            # keyword instead of the deprecated ``add_(Number, Tensor)`` overload.
            self.running_size.data.mul_(self.decay).add_(size, alpha=1 - self.decay)

            # Per-code sum of the encoder vectors assigned to it. Named ``code_sum``
            # to avoid shadowing the builtin ``sum``.
            b, h, w, c = z_e_x_.size()
            code_sum = torch.zeros_like(self.running_sum, dtype=z_e_x_.dtype, device=device)
            code_sum.index_add_(dim=0, index=indices, source=z_e_x_.view(b * h * w, c))
            if comm.get_world_size() > 1:
                code_sum = AllReduce.apply(code_sum)
            self.running_sum.data.mul_(self.decay).add_(code_sum, alpha=1 - self.decay)

            # Add eps to every count while rescaling so the smoothed counts still
            # total n. This keeps codes with zero usage from dividing by zero.
            n = self.running_size.sum()
            size_ = (self.running_size + self.eps) / (n + self.K * self.eps) * n
            self.embedding.weight.data.copy_(self.running_sum / size_.unsqueeze(1))

    z_q_x_bar_flatten = torch.index_select(self.embedding.weight, dim=0, index=indices)
    z_q_x_bar_ = z_q_x_bar_flatten.view_as(z_e_x_)
    z_q_x_bar = z_q_x_bar_.permute(0, 3, 1, 2).contiguous()

    return z_q_x, z_q_x_bar
def default_setup(cfg, args):
    """
    Perform some basic common setups at the beginning of a job, including:

    1. Set up the vidgen logger
    2. Log basic information about environment, cmdline arguments, and config
    3. Backup the config to the output directory

    It also seeds the RNGs per rank, sets cuDNN benchmark mode (unless
    ``args.eval_only``), and switches ``torch.multiprocessing`` to the
    ``file_system`` sharing strategy. These are process-wide side effects.

    Args:
        cfg (CfgNode): the full config to be used
        args (argparse.NameSpace): the command line arguments to be logged
    """
    output_dir = cfg.OUTPUT_DIR
    if comm.is_main_process() and output_dir:
        PathManager.mkdirs(output_dir)

    rank = comm.get_rank()
    setup_logger(output_dir, distributed_rank=rank, name="fvcore")
    logger = setup_logger(output_dir, distributed_rank=rank)

    logger.info("Rank of current process: {}. World size: {}".format(
        rank, comm.get_world_size()))
    logger.info("Environment info:\n" + collect_env_info())

    logger.info("Command line arguments: " + str(args))
    # A missing, empty or None config_file all mean "nothing to dump".
    config_file = getattr(args, "config_file", "")
    if config_file:
        # Use a context manager so the handle is closed promptly instead of leaking.
        with PathManager.open(config_file, "r") as f:
            config_contents = f.read()
        logger.info("Contents of args.config_file={}:\n{}".format(
            config_file, config_contents))

    logger.info("Running with full config:\n{}".format(cfg))
    if comm.is_main_process() and output_dir:
        # Note: some of our scripts may expect the existence of
        # config.yaml in output directory
        path = os.path.join(output_dir, "config.yaml")
        with PathManager.open(path, "w") as f:
            f.write(cfg.dump())
        logger.info("Full config saved to {}".format(os.path.abspath(path)))

    # make sure each worker has a different, yet deterministic seed if specified
    seed_all_rng(None if cfg.SEED < 0 else cfg.SEED + rank)

    # cudnn benchmark has large overhead. It shouldn't be used considering the small size of
    # typical validation set.
    if not (hasattr(args, "eval_only") and args.eval_only):
        torch.backends.cudnn.benchmark = cfg.CUDNN_BENCHMARK

    # On Linux, torch's default "file_descriptor" sharing strategy keeps one open file
    # descriptor per tensor shared between processes (e.g. DataLoader workers). With
    # large batches this can exhaust the open-file limit. "file_system" avoids that,
    # at the risk of leaking shared memory if processes die abnormally (see the
    # torch.multiprocessing docs). Reportedly this also reduced data loading time here.
    torch.multiprocessing.set_sharing_strategy('file_system')
def __init__(self, size: int, n_samples=0):
    """
    Args:
        size (int): the total number of data of the underlying dataset to sample from
        n_samples (int): if > 0, keep a random subset (drawn without replacement
            using NumPy's global RNG) of at most this many indices from this
            rank's shard. If the shard holds fewer indices, all of them are kept,
            in random order.
    """
    self._size = size
    assert size > 0
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Contiguous sharding: each rank gets ceil(size / world_size) consecutive
    # indices. Trailing ranks may get fewer, or none when world_size > size.
    shard_size = (self._size - 1) // self._world_size + 1
    begin = shard_size * self._rank
    end = min(shard_size * (self._rank + 1), self._size)
    self._local_indices = range(begin, end)

    if n_samples > 0:
        # Clamp: np.random.choice(..., replace=False) raises if asked for more
        # samples than the (possibly empty) shard contains.
        n_keep = min(n_samples, len(self._local_indices))
        self._local_indices = np.random.choice(self._local_indices, n_keep, replace=False)
def forward(self, input):
    """
    Batch normalization whose training-time statistics are all-reduced across ranks.

    With a single process, or in eval mode, this defers to the parent class's
    ``forward``. Otherwise the per-channel mean and mean-of-squares are reduced over
    dims [0, 2, 3] and combined across ranks with ``AllReduce``. The input is then
    normalized with the resulting statistics and the running stats are updated.

    Two reduction modes, selected by ``self._stats_mode``:
      * ``""``: every rank's statistics are weighted equally, regardless of its
        local batch size. A zero local batch is rejected by an assert.
      * otherwise: statistics are weighted by each rank's batch size ``B``. Ranks
        with an empty batch are allowed. If the global batch is empty, the running
        stats are left unchanged.

    Args:
        input: the reduction dims [0, 2, 3] and the ``reshape(1, -1, 1, 1)`` of the
            affine terms imply a 4-D (N, C, H, W) tensor.

    Returns:
        ``input * scale + bias``, with the same shape as ``input``.

    NOTE(review): ``var`` is the biased estimate E[x^2] - E[x]^2 and is also used to
    update ``running_var``. torch.nn.BatchNorm uses the unbiased estimate for its
    running variance. Confirm this difference is intended.
    NOTE(review): ``self.momentum`` is multiplied directly, so a ``momentum=None``
    (cumulative average) configuration would fail here.
    """
    if comm.get_world_size() == 1 or not self.training:
        return super().forward(input)

    B, C = input.shape[0], input.shape[1]

    # Local per-channel first and second moments.
    mean = torch.mean(input, dim=[0, 2, 3])
    meansqr = torch.mean(input * input, dim=[0, 2, 3])

    if self._stats_mode == "":
        assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
        # Unweighted average of the per-rank moments.
        vec = torch.cat([mean, meansqr], dim=0)
        vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
        mean, meansqr = torch.split(vec, C)
        momentum = self.momentum
    else:
        if B == 0:
            vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
            vec = vec + input.sum()  # make sure there is gradient w.r.t input
        else:
            # The trailing 1 becomes B after the multiply below, so after the
            # all-reduce vec[-1] holds the global batch size.
            vec = torch.cat(
                [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
            )
        # Scale the moments by B so the all-reduced sum is a batch-weighted total.
        vec = AllReduce.apply(vec * B)

        total_batch = vec[-1].detach()
        momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
        total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
        mean, meansqr, _ = torch.split(vec / total_batch, C)

    var = meansqr - mean * mean
    invstd = torch.rsqrt(var + self.eps)
    # Fold normalization and the affine weight/bias into one per-channel scale and bias.
    scale = self.weight * invstd
    bias = self.bias - mean * scale
    scale = scale.reshape(1, -1, 1, 1)
    bias = bias.reshape(1, -1, 1, 1)

    # Exponential moving average of the global statistics (momentum may be 0).
    self.running_mean += momentum * (mean.detach() - self.running_mean)
    self.running_var += momentum * (var.detach() - self.running_var)
    return input * scale + bias
def __init__(self, size: int, shuffle: bool = True, seed: Optional[int] = None):
    """
    Args:
        size (int): number of items in the underlying dataset to draw indices from
        shuffle (bool): whether the indices should be shuffled
        seed (int): starting seed for the shuffle. It must be identical on every
            worker. When None, a random seed shared by all workers is obtained from
            ``comm.shared_random_seed()``, which requires the workers to synchronize.
    """
    self._size = size
    assert size > 0
    self._shuffle = shuffle

    resolved_seed = seed if seed is not None else comm.shared_random_seed()
    self._seed = int(resolved_seed)

    self._rank, self._world_size = comm.get_rank(), comm.get_world_size()
def __init__(self, dataset_dicts, repeat_thresh, shuffle=True, seed=None):
    """
    Args:
        dataset_dicts (list[dict]): annotations in vidgen dataset format.
        repeat_thresh (float): frequency threshold; data below it gets repeated.
        shuffle (bool): whether the indices should be shuffled
        seed (int): starting seed for the shuffle. It must be identical on every
            worker. When None, a random seed shared by all workers is obtained from
            ``comm.shared_random_seed()``, which requires the workers to synchronize.
    """
    self._shuffle = shuffle
    self._seed = int(comm.shared_random_seed() if seed is None else seed)
    self._rank = comm.get_rank()
    self._world_size = comm.get_world_size()

    # Each entry gets a possibly fractional repeat factor. Store its whole-number
    # part and its fractional remainder separately.
    factors = self._get_repeat_factors(dataset_dicts, repeat_thresh)
    whole = torch.trunc(factors)
    self._int_part = whole
    self._frac_part = factors - whole
def build_train_loader(cfg, mapper=None):
    """
    A data loader is created by the following steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Start workers to work on the dicts. Each worker will:

       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.

    The batched ``list[mapped_dict]`` is what this dataloader will return.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        tuple: ``(data_loader, dataset_len)``.

        * ``data_loader`` is a ``torch.utils.data.DataLoader`` yielding lists of
          ``images_per_batch // world_size`` mapped dicts per process; incomplete
          batches are dropped. Whether it is infinite depends on the sampler's
          ``__iter__`` (presumably infinite for these samplers; confirm).
        * ``dataset_len`` is ``len(dataset)``.

    Raises:
        AssertionError: if ``SOLVER.IMS_PER_BATCH`` is not a positive multiple of
            the world size.
        ValueError: if ``DATALOADER.SAMPLER_TRAIN`` is not a known sampler.
    """
    # Number of distributed processes. This is not DATALOADER.NUM_WORKERS.
    world_size = get_world_size()
    images_per_batch = cfg.SOLVER.IMS_PER_BATCH
    assert (
        images_per_batch % world_size == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        images_per_batch, world_size
    )
    assert (
        images_per_batch >= world_size
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
        images_per_batch, world_size
    )
    images_per_worker = images_per_batch // world_size

    dataset_dicts = get_dataset_dicts(cfg.DATASETS.TRAIN)
    dataset = DatasetFromList(dataset_dicts, copy=False)

    if mapper is None:
        mapper = DatasetMapper(cfg, True)
    dataset = MapDataset(dataset, mapper)

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logger = logging.getLogger(__name__)
    logger.info("Using training sampler {}".format(sampler_name))
    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        sampler = samplers.RepeatFactorTrainingSampler(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    batch_sampler = torch.utils.data.sampler.BatchSampler(
        sampler, images_per_worker, drop_last=True
    )
    # drop_last so the batch always have the same size
    data_loader = torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=batch_sampler,
        collate_fn=trivial_batch_collator,
        worker_init_fn=worker_init_reset_seed,
    )

    return data_loader, len(dataset)