def __init__(
    self,
    module: nn.Module,
    sharded_optimizer: Union[OSS, List[OSS]],
    process_group: Any = None,
    broadcast_buffers: bool = True,
    sync_models_at_startup: bool = True,
):
    super().__init__()

    self.module = module
    self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
    self.enable_broadcast_buffers = broadcast_buffers

    # Handle a no_sync() context which prevents the gradient synchronization;
    # accumulate in place instead
    self.should_accumulate_grads = False

    # Communication related attributes
    self.process_group = process_group if process_group is not None else dist.group.WORLD
    self.world_size = dist.get_world_size(self.process_group)
    self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
    self.rank = dist.get_rank(self.process_group)
    self.global_rank = OSS.get_global_rank(self.process_group, self.rank)

    # Expose some of the PytorchDDP attributes, some frameworks rely on them.
    # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
    # device_id related logic is not present, this is not handled
    devices = {p.device for p in self.module.parameters()}
    self.is_multi_device_module = len(devices) > 1
    self.device = list(devices)[0]

    distinct_device_types = {p.device.type for p in self.module.parameters()}
    assert len(distinct_device_types) == 1, (
        "ShardedDataParallel's input module must be on "
        "the same type of devices, but input module parameters are located on {} different device types."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    # Scaffolding to be able to reduce the grads during the BW pass.
    # Several optimizers can be present, each working on a separate parameter set;
    # we build an iterator which goes through all the parameters involved globally
    self._param_iterator = chain(*[optim.should_bucket_param.keys() for optim in self.sharded_optimizers])
    self._grad_to_be_reduced = [True for _ in self._param_iterator]
    self._grad_accs: List[Callable] = []
    self._setup_backward_hooks()

    # Make sure that all ranks start with the same model
    if sync_models_at_startup:
        self._sync_params_and_buffers()
def reduce(*_: Any) -> None:
    # If gradients are being accumulated (no_sync context) or this grad has already been
    # reduced during this backward pass, skip the reduction and leave the status flags untouched
    if not self.should_accumulate_grads and self._grad_to_be_reduced[index]:
        assert param.grad is not None, "Reducing gradients during backward pass, cannot be None"

        if not self._bucket_flush_callback_set:
            Variable._execution_engine.queue_callback(self._flush_buckets)
            self._bucket_flush_callback_set = True

        # Make sure that this is not fired twice
        self._grad_to_be_reduced[index] = False
        param.grad.mul_(self.world_size_scaling)

        if self.reduce_fp16:
            param.grad.data = param.grad.data.half()

        # Future work includes clearing up the buffer if possible
        def cleanup() -> None:
            if dst_rank != self.global_rank:
                param.grad = None
            else:
                assert param.grad is not None
                param.grad.data = param.grad.data.to(dtype=param.dtype)

        # Async reduce for this buffer, log the future
        dst_global_rank = OSS.get_global_rank(self.process_group, dst_rank)

        self._work_handles.append(
            Workhandle(
                handle=dist.reduce(
                    tensor=param.grad.data,
                    dst=dst_global_rank,
                    group=self.process_group,
                    async_op=True,
                ),
                callback=cleanup,
            )
        )
        self._reduced_grads += 1

        # Opportunistically try to empty the queue
        self._try_consume_work_handle()

        # If all the reduce operations have been called,
        # make sure that all the asynchronous calls have concluded before moving on
        # and execute the delayed actions (release gradients, unroll the buckets)
        if self._reduced_grads == self._reduced_grads_max:
            self._consume_work_handles()
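# The closure above is attached to each parameter's gradient accumulation node by
# _setup_backward_hooks. Below is a minimal sketch of that registration pattern, assuming a
# `param` that requires grad and a `reduce_fn` closure like the one above; the function name
# and arguments are illustrative only and rely on the module's existing torch/typing imports.
def _register_reduce_hook_sketch(param: torch.Tensor, reduce_fn: Callable) -> Any:
    # expand_as adds a cheap autograd node whose predecessor is the AccumulateGrad node of the
    # leaf parameter; hooking that node fires reduce_fn once per backward pass for this parameter,
    # right after its gradient has been accumulated.
    p_tmp = param.expand_as(param)
    assert p_tmp.grad_fn is not None
    grad_acc = p_tmp.grad_fn.next_functions[0][0]
    grad_acc.register_hook(reduce_fn)
    # The caller must keep a reference to grad_acc (the real code appends it to self._grad_accs),
    # otherwise the hook can be garbage collected.
    return grad_acc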
def __init__(
    self,
    module: nn.Module,
    sharded_optimizer: Union[OSS, List[OSS]],
    process_group: Any = None,
    broadcast_buffers: bool = True,
    sync_models_at_startup: bool = True,
):
    super().__init__()

    self.module = module
    self.sharded_optimizers = [sharded_optimizer] if isinstance(sharded_optimizer, OSS) else sharded_optimizer
    self.enable_broadcast_buffers = broadcast_buffers

    # Handle a no_sync() context which prevents the gradient synchronization;
    # accumulate in place instead
    self.should_accumulate_grads = False

    # Communication related attributes
    self.process_group = process_group if process_group is not None else dist.group.WORLD
    self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
    self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
    self.rank = dist.get_rank(self.process_group)
    self.global_rank = OSS.get_global_rank(self.process_group, self.rank)

    # Expose some of the PytorchDDP attributes, some frameworks rely on them.
    # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
    # device_id related logic is not present, this is not handled
    devices = {p.device for p in self.module.parameters()}
    self.is_multi_device_module = len(devices) > 1
    self.device = list(devices)[0]

    distinct_device_types = {p.device.type for p in self.module.parameters()}
    assert len(distinct_device_types) == 1, (
        "ShardedDataParallel's input module must be on "
        "the same type of devices, but input module parameters are located on {} different device types."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    # Scaffolding to be able to reduce the grads during the BW pass.
    # Several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks:
    # - we build an iterator which goes through all the parameters involved globally
    all_param_iterator = chain(
        *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
    )
    self._grad_to_be_reduced = [True for _ in filter(lambda x: x.requires_grad, all_param_iterator)]

    # - keep track of the grads which have already been reduced
    self._reduced_grads: Dict[OSS, int] = {}
    self._reduced_grads_max = {o: len(o.param_to_rank.values()) for o in self.sharded_optimizers}
    self._clear_counters()

    # - setup backward hooks which will be called by Torch's autograd in due time
    self._grad_accs: List[Callable] = []
    self._setup_backward_hooks()

    # passing a handle to torch.nn.SyncBatchNorm layer
    self._passing_sync_batchnorm_handle(self.module)

    # Make sure that all ranks start with the same model
    if sync_models_at_startup:
        self._sync_params_and_buffers()
def __init__(
    self,
    module: nn.Module,
    sharded_optimizer: Union[OSS, List[OSS]],
    process_group: Any = None,
    broadcast_buffers: bool = True,
    sync_models_at_startup: bool = True,
    reduce_buffer_size: int = 2**23,
    auto_refresh_trainable: bool = True,
    reduce_fp16: bool = False,
):
    super().__init__()

    self.module = module
    self.sharded_optimizers = [sharded_optimizer] if not isinstance(sharded_optimizer, list) else sharded_optimizer
    self.enable_broadcast_buffers = broadcast_buffers
    self.auto_refresh_trainable = auto_refresh_trainable
    self.reduce_fp16 = reduce_fp16
    if reduce_buffer_size > 0 and reduce_fp16:
        self.reduce_fp16 = False
        logging.warning(
            "fp16 gradient reduction is not compatible with reduction buffers, which are requested. fp16 grad reduction is deactivated."
        )

    # Handle a no_sync() context which prevents the gradient synchronization;
    # accumulate in place instead
    self.should_accumulate_grads = False
    self.accumulate_grads_flipped = False

    # Communication related attributes
    self.process_group = process_group if process_group is not None else dist.group.WORLD
    self.backend = dist.get_backend(self.process_group)
    self.world_size_scaling = 1.0 / dist.get_world_size(self.process_group)  # > 0
    self.reference_global_rank = OSS.get_global_rank(self.process_group, 0)  # picking rank 0 as the reference
    self.rank = dist.get_rank(self.process_group)
    self.global_rank = OSS.get_global_rank(self.process_group, self.rank)
    self._local_to_global_rank = [
        OSS.get_global_rank(self.process_group, i) for i in range(dist.get_world_size(self.process_group))
    ]

    # Expose some of the PytorchDDP attributes, some frameworks rely on them.
    # See https://pytorch.org/docs/stable/_modules/torch/nn/parallel/distributed.html#DistributedDataParallel
    # device_id related logic is not present, this is not handled
    devices = {p.device for p in self.module.parameters()}
    self.is_multi_device_module = len(devices) > 1
    self.device = list(devices)[0]

    distinct_device_types = {p.device.type for p in self.module.parameters()}
    assert len(distinct_device_types) == 1, (
        "ShardedDataParallel's input module must be on "
        "the same type of devices, but input module parameters are located on {} different device types."
    ).format(distinct_device_types)
    self.device_type = list(distinct_device_types)[0]

    # Scaffolding to be able to reduce the grads during the BW pass.
    # Several optimizers can be present, each working on a separate parameter set which is spread across multiple ranks:
    # - we build an iterator which goes through all the parameters involved globally
    self._all_params = list(
        chain(
            *[sum([sum(p, []) for p in optim.per_device_params.values()], []) for optim in self.sharded_optimizers]
        )
    )
    self._trainable_params: List[torch.Tensor] = []
    self._grad_to_be_reduced: List[bool] = []
    self._trainable_param_to_rank: Dict[torch.Tensor, int] = {}
    self._reference_trainable_mask = list(map(_trainable, self._all_params))

    # - setup buckets and tensor views
    model_size = sum([p.numel() for p in self.module.parameters()])
    self.buffer_max_size = min(reduce_buffer_size, model_size)

    if dist.get_world_size(self.process_group) == 1:
        self.buffer_max_size = 0
        logging.info("Training is not really distributed, single rank. Deactivating buckets")

    logging.info(
        "ShardedDDP bucket size: {:.2f}M parameters, model size {:.2f}M parameters".format(
            self.buffer_max_size / 2**20, model_size / 2**20
        )
    )
    self.use_buckets = self.buffer_max_size > 0

    self.buckets: Dict[torch.device, Dict[int, GradBucket]] = {}
    self._should_bucket_grad: List[bool] = []
    self._bucket_list: List[GradBucket] = []

    # - setup backward hooks which will be called by Torch's autograd in due time
    self._grad_accs: List[Callable] = []
    self._grad_hooks: List[Any] = []
    self._manual_reduce: List[Callable] = []

    # passing a handle to torch.nn.SyncBatchNorm layer
    self._passing_sync_batchnorm_handle(self.module)

    # Make sure that all ranks start with the same model
    if sync_models_at_startup:
        self._sync_params_and_buffers()

    self._work_handles: Deque[Workhandle] = deque()
    self._bucket_flush_callback_set = False
def _reduce_grads_task(
    buffers: List[torch.Tensor],
    per_rank_params: List[List[Parameter]],
    group: Any,
    self_rank: int,
    world_size: int,
) -> None:
    """Helper to reduce a list of params. The params are sorted by size, smallest first,
    which allows for an opportunistic bucketing.

    NOTE: All param gradients are assumed to exist"""

    buffer_size = buffers[0].numel()
    bucket_requests = []
    requests = []

    for (rank, params), buffer in zip(enumerate(per_rank_params), buffers):
        # All the params are sorted per rank and per increasing size
        if len(params) == 0:
            continue

        for p in params:
            if p.grad is None:
                p.grad = torch.zeros_like(p)

        global_rank = OSS.get_global_rank(group, rank)

        # Copy small gradients into per-GPU buffers and then async reduce
        i_bucketed = 0  # the number of tensors packed in the buffer
        offset = 0

        # Since all the parameters are already sorted per increasing size, we only need to consider the first ones.
        while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
            end = offset + params[i_bucketed].numel()
            buffer[offset:end].copy_(params[i_bucketed].grad.data.view(-1))  # type: ignore
            offset = end
            i_bucketed += 1

        if i_bucketed > 0:
            buffer.div_(world_size)  # type: ignore
            bucket_requests.append(
                (
                    dist.reduce(tensor=buffer, dst=global_rank, group=group, async_op=True),  # type: ignore
                    rank,
                )
            )

        # Directly reduce the other grads
        for p in params[i_bucketed:]:
            p.grad = cast(Tensor, p.grad)
            if p.grad.requires_grad:
                raise RuntimeError("DistributedDataParallel only works with gradients that don't require grad")

            p.grad.div_(world_size)  # type: ignore
            requests.append(dist.reduce(tensor=p.grad, dst=global_rank, group=group, async_op=True))  # type: ignore

    # Unroll the initial packed small gradients, as soon as possible
    for future, rank in bucket_requests:
        future.wait()

        if rank == self_rank:
            i_bucketed = 0  # the number of tensors packed in the buffer
            offset = 0
            params = per_rank_params[rank]
            buffer = buffers[rank]

            while i_bucketed < len(params) and offset + params[i_bucketed].numel() < buffer_size:
                end = offset + params[i_bucketed].numel()
                params[i_bucketed].grad.data.copy_(buffer[offset:end].view_as(params[i_bucketed]))  # type: ignore
                offset = end
                i_bucketed += 1

    # Make sure that we're done with this device before moving on and cleaning the unused params
    _ = list(map(lambda x: x.wait(), requests))
def _setup_bucket_strategy(self) -> None:
    """Devise a bucketing strategy on a per-rank ownership level.
    These buckets will not be sharded, since the gradients would be re-allocated during the backward in that case.
    This method can be slow for big models, but it is not typically called often (not for every forward, for instance).
    """

    # A priori, one reduce call per param
    self._reduced_grads_max = len(self._trainable_params)

    if not self.use_buckets:
        return

    # Devise the bucketing strategy. Parameters are already sorted, in that:
    # - these are only the trainable parameters, so they should produce grads
    # - they are sorted by increasing size
    self.buckets = {}

    for param in self._trainable_params:
        device = param.device
        dst_rank = self._trainable_param_to_rank[param]

        if param.device not in self.buckets.keys():
            self.buckets[param.device] = [
                Bucket(buffer=torch.zeros(self.buffer_max_size, dtype=param.dtype, device=device))
                for _ in range(dist.get_world_size(self.process_group))
            ]
        bucket = self.buckets[device][dst_rank]
        bucket.destination = OSS.get_global_rank(self.process_group, dst_rank)

        # Criteria to decide whether this parameter is to be bucketed or not:
        # - enough room in the bucket
        if (bucket.fill + param.numel()) < self.buffer_max_size:
            self._should_bucket_grad.append(True)

            # This parameter's gradient becomes a view of the bucket
            fill_next = bucket.fill + param.numel()
            if param.grad is None:
                # will be overwritten just below, see next line
                param.grad = torch.zeros_like(param)
            param.grad.data = bucket.buffer[bucket.fill:fill_next].view_as(param.data)
            bucket.fill = fill_next

            # Update the bucket
            self._reduced_grads_max -= 1  # one less reduce call per bucketed grad
            self.buckets[device][dst_rank].max_params_checked_in += 1
        else:
            self._should_bucket_grad.append(False)

    self._bucket_list = list(chain(*[self.buckets[device] for device in self.buckets.keys()]))

    # Resize the buckets to remove the unused space at the end
    for bucket in self._bucket_list:
        bucket.buffer.resize_(bucket.fill)
        bucket.sent = True
        if bucket.max_params_checked_in > 0:
            self._reduced_grads_max += 1  # one reduce call per bucket
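# A minimal, single-process sketch of the bucketing trick used above (illustrative only, not part
# of this module): each small gradient is turned into a view of one flat buffer, so a single
# collective on that buffer covers all the bucketed gradients. The names below are local to this example.
def _bucket_view_sketch() -> None:
    import torch

    params = [torch.nn.Parameter(torch.randn(4)) for _ in range(3)]
    buffer = torch.zeros(sum(p.numel() for p in params))

    offset = 0
    for p in params:
        end = offset + p.numel()
        p.grad = torch.zeros_like(p)
        p.grad.data = buffer[offset:end].view_as(p)  # the grad now aliases a slice of the bucket
        offset = end

    # Writing into the individual grads fills the shared buffer in place...
    for p in params:
        p.grad.copy_(torch.ones_like(p))
    assert torch.equal(buffer, torch.ones(buffer.numel()))
    # ...so a single dist.reduce(buffer, dst, async_op=True) would reduce all of these
    # gradients at once instead of issuing one reduce call per parameter.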