def _all_gather_list_sync(
    self,
    logging_outputs: List[Dict[str, Any]],
    *extra_stats_to_sum,
    ignore=False,
):
    """
    Sync logging outputs and extra statistics across workers.

    Each worker contributes ``[logging_outputs, *extra_stats_to_sum]`` through
    ``distributed_utils.all_gather_list`` over the data-parallel process
    group. This variant is suitable when logging outputs are complex types.

    Args:
        logging_outputs: this worker's logging outputs.
        *extra_stats_to_sum: additional values; each position is summed
            over the gathered results with the builtin ``sum``.
        ignore: when true, this worker contributes an empty list of logging
            outputs (its extra stats are still sent).

    Returns:
        A ``(logging_outputs, extra_stats_to_sum)`` pair: the gathered
        logging outputs flattened into a single list, and a list holding
        the sum for each extra stat.

    Raises:
        NotImplementedError: if ``self.tpu`` is set.
    """
    if self.tpu:
        raise NotImplementedError

    local_outputs = [] if ignore else logging_outputs
    payload = [local_outputs, *extra_stats_to_sum]

    gathered = distributed_utils.all_gather_list(
        payload,
        # buffer size comes from the config when present, else 16384
        max_size=getattr(self.cfg.common, "all_gather_list_size", 16384),
        group=self.data_parallel_process_group,
    )

    # Transpose: one tuple per payload position instead of one list per
    # gathered entry. Position 0 holds the logging outputs.
    by_position = list(zip(*gathered))
    gathered_outputs = by_position[0]
    gathered_stats = by_position[1:]

    merged_outputs = [
        entry for worker_outputs in gathered_outputs for entry in worker_outputs
    ]
    summed_stats = [sum(values) for values in gathered_stats]
    return merged_outputs, summed_stats
def __init__(self, cfg: FairseqConfig, task, model, criterion, quantizer=None):
    """
    Set up trainer state: config, device, model/criterion placement and
    bookkeeping attributes.

    Args:
        cfg: configuration; an ``argparse.Namespace`` is accepted but
            deprecated and is converted with
            ``convert_namespace_to_omegaconf``.
        task: stored as ``self.task``.
        model: cast to fp16/bf16 when ``cfg.common`` requests it, and moved
            to ``self.device`` unless pipeline model parallelism or the
            distributed wrapper is in use.
        criterion: handled the same way as ``model``.
        quantizer: optional; when given, ``quantizer.set_trainer(self)`` is
            called.
    """
    if isinstance(cfg, Namespace):
        logger.warning(
            "argparse.Namespace configuration is deprecated! Automatically converting to OmegaConf"
        )
        cfg = convert_namespace_to_omegaconf(cfg)

    self.cfg = cfg
    self.task = task

    common_cfg = cfg.common
    dist_cfg = cfg.distributed_training

    # Record shared parameters before any dtype/device change so they can
    # be re-tied afterwards.
    shared_params = _catalog_shared_params(model)

    self.tpu = common_cfg.tpu
    self.cuda = torch.cuda.is_available() and not (common_cfg.cpu or self.tpu)
    if self.cuda:
        device = torch.device("cuda")
    elif self.tpu:
        device = utils.get_tpu_device()
    else:
        device = torch.device("cpu")
    self.device = device

    # Cast model and criterion to the requested dtype (fp16 wins over bf16).
    cast_criterion, cast_model = criterion, model
    if common_cfg.fp16:
        cast_criterion = cast_criterion.half()
        cast_model = cast_model.half()
    elif common_cfg.bf16:
        cast_criterion = cast_criterion.to(dtype=torch.bfloat16)
        cast_model = cast_model.to(dtype=torch.bfloat16)
    self._criterion = cast_criterion
    self._model = cast_model

    # The DistributedFairseqModel wrapper will handle moving to device, so
    # only move here when neither pipeline parallelism nor the wrapper is used.
    uses_pipeline = dist_cfg.pipeline_model_parallel
    if not (uses_pipeline or self.use_distributed_wrapper):
        self._criterion = self._criterion.to(device=self.device)
        self._model = self._model.to(device=self.device)
    self.pipeline_model_parallel = uses_pipeline
    self.last_device = (
        torch.device(dist_cfg.pipeline_devices[-1])
        if self.cuda and self.pipeline_model_parallel
        else None
    )

    # Make sure shared parameters are still shared after the transfer: every
    # alias path is pointed back at the module found under the first path.
    for param_group in shared_params:
        canonical_path = param_group[0]
        ref = _get_module_by_path(self._model, canonical_path)
        for alias_path in param_group[1:]:
            logger.info(
                "detected shared parameter: {} <- {}".format(
                    canonical_path, alias_path
                )
            )
            _set_module_by_path(self._model, alias_path, ref)

    self._dummy_batch = None  # no dummy batch is available at first
    self._lr_scheduler = None
    self._num_updates = 0
    self._num_xla_compiles = 0  # for TPUs
    self._optim_history = None
    self._optimizer = None
    self._warn_once = set()
    self._wrapped_criterion = None
    self._wrapped_model = None

    # TODO(myleott): support tpu
    self._grad_norm_buf = (
        torch.cuda.DoubleTensor(self.data_parallel_world_size)
        if self.cuda and self.data_parallel_world_size > 1
        else None
    )

    self.quantizer = quantizer
    if quantizer is not None:
        quantizer.set_trainer(self)

    # Collect detailed CUDA environment information.
    if not self.cuda:
        self.cuda_env = None
        self.cuda_env_arr = None
    else:
        self.cuda_env = utils.CudaEnvironment()
        if self.data_parallel_world_size > 1:
            self.cuda_env_arr = distributed_utils.all_gather_list(
                self.cuda_env, group=distributed_utils.get_global_group()
            )
        else:
            self.cuda_env_arr = [self.cuda_env]
        # only rank 0 prints the gathered environments
        if self.data_parallel_rank == 0:
            utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr)

    metrics.log_start_time("wall", priority=790, round=0)

    self._start_time = time.time()
    self._previous_training_time = 0
    self._cumulative_training_time = None
def _test_all_gather_list_rank_tensor(rank, group):
    """
    Gather a one-element tensor holding ``rank`` and check that the entry at
    position ``i`` of the result carries the value ``i``.
    """
    local_tensor = torch.tensor([rank])
    gathered = dist_utils.all_gather_list(local_tensor, group)
    for position, received in enumerate(gathered):
        assert received.item() == position
def _test_all_gather_list_equality(ref_obj, rank, group):
    """
    Gather ``ref_obj`` and check that every returned object is equal to it
    according to ``objects_are_equal``.

    ``rank`` is accepted but not used by this check.
    """
    gathered = dist_utils.all_gather_list(ref_obj, group)
    for received in gathered:
        assert objects_are_equal(ref_obj, received)