def __init__(
    self,
    params: _params_t,
    optim: Type[Optimizer] = SGD,
    group: Optional[Any] = None,
    broadcast_buffer_size: int = -1,
    broadcast_fp16: bool = False,
    **default: Any,
):
    # Hold all the model params in the root .param_groups
    self.in_super_constructor = True
    super().__init__(params, default)
    self.in_super_constructor = False

    # Partition information. Lazy evaluation, computed when requested
    self.__per_device_params: Dict[torch.device, List[List[Parameter]]] = OrderedDict()  # device, rank, params
    self.__param_rank: Dict[torch.Tensor, int] = {}
    self._partition_parameters: List[List[dict]] = []
    self.__param_to_index: Dict[int, int] = {}
    self.__local_params: Optional[List[torch.Tensor]] = None

    # Default empty values + immutables
    self._optim_defaults = default
    self._optim_constructor = optim

    self.group = group if group is not None else dist.group.WORLD
    self.world_size = dist.get_world_size(self.group)
    self.backend = dist.get_backend(self.group)
    self.rank = dist.get_rank(self.group)
    self.global_rank = get_global_rank(self.group, self.rank)
    self._local_to_global_rank = [get_global_rank(self.group, i) for i in range(self.world_size)]

    self.broadcast_fp16 = broadcast_fp16
    self.buckets: Dict[torch.device, Dict[int, ParamBucket]] = {}
    self._all_states: List[Dict[str, Any]] = []  # Optional consolidated optimizer state
    self._default_device = torch.device("cpu")

    # Setup everything which is related to the parameters to be trained
    # (partition and optimizer for the shard)
    self.refresh_trainable()
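# Usage sketch for the constructor above (illustrative, not part of the class):
# it assumes torch.distributed has already been initialized and that OSS is the
# class this __init__ belongs to. The helper name `_example_oss_usage`, the model,
# and the lr value are hypothetical; extra keyword arguments reach the wrapped
# optimizer through **default.
def _example_oss_usage():
    import torch

    model = torch.nn.Linear(1024, 1024)

    # Each rank keeps optimizer state only for its own shard of the parameters;
    # the wrapped optimizer (SGD here) is built per shard by refresh_trainable().
    optimizer = OSS(
        params=model.parameters(),
        optim=torch.optim.SGD,  # any torch.optim.Optimizer subclass
        lr=0.01,                # forwarded to the wrapped optimizer via **default
    )
    return optimizer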
def __init__(
    self,
    module: nn.Module,
    sharded_optimizer: Union[OSS, List[OSS]],
    process_group: Any = None,
    broadcast_buffers: bool = True,
    sync_models_at_startup: bool = True,
    reduce_buffer_size: int = 2 ** 23,
    auto_refresh_trainable: bool = True,
    reduce_fp16: bool = False,
):
    super().__init__()

    # This field needs to be exposed to ensure interface parity with DDP
    self.module = module

    self._sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
    self._enable_broadcast_buffers = broadcast_buffers
    self._auto_refresh_trainable = auto_refresh_trainable
    self._reduce_fp16 = reduce_fp16
    if reduce_buffer_size > 0 and reduce_fp16:
        self._reduce_fp16 = False
        logging.warning(
            "fp16 gradient reduction is not compatible with reduction buffers, which were requested. fp16 grad reduction is deactivated."
        )

    # Handle a no_sync() context which prevents the gradient synchronization,
    # accumulate in place
    self._should_accumulate_grads = False
    self._accumulate_grads_flipped = False

    # Communication related attributes
    self._process_group = process_group if process_group is not None else dist.group.WORLD
    self._backend = dist.get_backend(self._process_group)
    self._world_size_scaling = 1.0 / dist.get_world_size(self._process_group)  # > 0
    self._reference_global_rank = get_global_rank(self._process_group, 0)  # picking rank 0 as the reference
    self._rank = dist.get_rank(self._process_group)
    self._global_rank = get_global_rank(self._process_group, self._rank)
    self._local_to_global_rank = [
        get_global_rank(self._process_group, i) for i in range(dist.get_world_size(self._process_group))
    ]

    # Expose some of the PytorchDDP attributes, some frameworks rely on them.
    # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
    # device_id related logic is not present, this is not handled
    devices = {p.device for p in self.module.parameters()}
    self.is_multi_device_module = len(devices) > 1
    distinct_device_types = {p.device.type for p in self.module.parameters()}
    assert len(distinct_device_types) == 1, (
        "ShardedDataParallel's input module must be on "
        "the same type of devices, but input module parameters are located on {} different device types."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    # Scaffolding to be able to reduce the grads during the BW pass.
    # Several optimizers can be present, each working on a separate parameter set
    # which is spread across multiple ranks
    # - we build an iterator which goes through all the parameters involved globally
    self._all_params = list(
        chain(
            *[
                sum([sum(p, []) for p in optim._per_device_params.values()], [])
                for optim in self._sharded_optimizers
            ]
        )
    )
    self._trainable_params: List[torch.Tensor] = []
    self._grad_to_be_reduced: List[bool] = []
    self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
    self._reference_trainable_mask = list(map(_trainable, self._all_params))

    # - setup buckets and tensor views
    model_size = sum([p.numel() for p in self.module.parameters()])
    self._buffer_max_size = min(reduce_buffer_size, model_size)

    if dist.get_world_size(self._process_group) == 1:
        self._buffer_max_size = 0
        logging.info("Training is not really distributed, single rank. Deactivating buckets")

    logging.info(
        "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
            self._buffer_max_size / 2 ** 20, model_size / 2 ** 20
        )
    )

    self._use_buckets = self._buffer_max_size > 0

    self._buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
    self._should_bucket_grad: List[bool] = []
    self._bucket_list: List[GradBucket] = []

    # - setup backward hooks which will be called by Torch's autograd in due time
    self._grad_accs: List[Callable] = []
    self._grad_hooks: List[Any] = []
    self._manual_reduce: List[Callable] = []

    # passing a handle to the torch.nn.SyncBatchNorm layers
    self._passing_sync_batchnorm_handle(self.module)

    # Make sure that all ranks start with the same model
    if sync_models_at_startup:
        self._sync_params_and_buffers()

    self._work_handles: Deque[Workhandle] = deque()
    self._bucket_flush_callback_set = False