def backward(ctx, grad_output):
    if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
        rank = dist.get_rank(group=ctx.group)
        world_size = dist.get_world_size(group=ctx.group)
        out_size = list(grad_output.size())
        if out_size[0] % world_size != 0:
            raise RuntimeError(
                f'Tensor with dimensions: {out_size} does '
                f'not have first dimension divisible by world_size: {world_size}'
            )
        out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group)
        gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype)
        dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
    else:
        raise RuntimeError("Backend not supported!")
    return (None, gx, None)
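# Hedged usage sketch (not part of the sources above): it illustrates the shape
# contract the backward above relies on. ``dist._reduce_scatter_base`` sum-reduces
# the flat input across ranks and writes this rank's chunk into the output, so the
# input's first dimension must be ``world_size`` times the output's. In recent
# PyTorch releases the same op is exposed as ``dist.reduce_scatter_tensor``.
# Assumes a single-node NCCL group launched with e.g. ``torchrun --nproc_per_node=2``.
import torch
import torch.distributed as dist


def _reduce_scatter_base_demo() -> None:
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)
    # Each rank contributes a (world_size * 2)-element tensor filled with its rank.
    full = torch.full((world_size * 2,), float(rank), device="cuda")
    shard = torch.empty(2, device="cuda")
    dist._reduce_scatter_base(shard, full)  # ReduceOp.SUM on the default group
    # Every element of the shard now equals 0 + 1 + ... + (world_size - 1).
    print(f"rank {rank}: {shard.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    _reduce_scatter_base_demo()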
def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor, output: torch.Tensor):
    r"""
    This FSDP communication hook implements ``reduce_scatter`` algorithm for
    sharded FSDP strategies and a necessary pre- and post-division of gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
            communicated across ranks.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    # Average grad by pre-division factor.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist._reduce_scatter_base(output, grad, group=state.process_group)
    # Average grad's shard by post-division factor.
    if state.gradient_postdivide_factor > 1:
        output.div_(state.gradient_postdivide_factor)
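# Hedged sketch: drives ``reduce_scatter_hook`` above by hand with a tiny
# stand-in for ``DefaultState`` (a SimpleNamespace carrying only the attributes
# the hook reads: the process group and the two division factors). In real use
# the state object comes from FSDP's communication-hook machinery; this only
# shows the data flow. Assumes a single-node NCCL group launched with torchrun.
import types

import torch
import torch.distributed as dist


def _reduce_scatter_hook_demo() -> None:
    dist.init_process_group("nccl")
    rank, world_size = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank)
    state = types.SimpleNamespace(
        process_group=dist.group.WORLD,
        gradient_predivide_factor=world_size,  # divide once before the reduction
        gradient_postdivide_factor=1,          # nothing to do afterwards
    )
    grad = torch.ones(world_size * 4, device="cuda") * (rank + 1)  # unsharded local grad
    shard = torch.empty(4, device="cuda")                          # this rank's output shard
    reduce_scatter_hook(state, grad, shard)
    print(f"rank {rank}: {shard.tolist()}")  # averaged gradient shard
    dist.destroy_process_group()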
def reduce_scatter_base(self, collectiveArgs, retFlag=False, pair=False):
    retObj = dist._reduce_scatter_base(
        output=collectiveArgs.opTensor,
        input=collectiveArgs.ipTensor,
        group=collectiveArgs.group,
        async_op=collectiveArgs.asyncOp,
    )  # synchronicity is maintained in runColl

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(retObj)

    if retFlag:
        return retObj
def flush(self) -> None:
    """Flush content of the bucket."""
    if self.offset == 0:
        assert len(self.callbacks) == 0
        return
    # reduce-scatter bucket
    if hasattr(dist, "_reduce_scatter_base"):
        dist._reduce_scatter_base(  # type: ignore
            self.output_shard[: self.offset],
            self.data[:, : self.offset].contiguous(),
            group=self.group,
        )
    else:
        dist.reduce_scatter(
            self.output_shard[: self.offset],
            list(self.data[:, : self.offset].unbind(0)),
            group=self.group,
        )
    # execute post-reduction callbacks
    for callback_fn in self.callbacks:
        callback_fn()
    # reuse input bucket but allocate a fresh output shard
    self.data[:, : self.offset].zero_()
    self.offset = 0
    self.callbacks.clear()
    self.output_shard = torch.zeros_like(self.data[0])
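# Hedged sketch: the two branches of ``flush`` compute the same thing.
# ``_reduce_scatter_base`` takes a single tensor whose ``world_size`` rows are
# the per-rank contributions, while the older ``reduce_scatter`` takes those
# rows as a list; the assert below checks the equivalence. Assumes an already
# initialized single-node NCCL group (e.g. via torchrun) and CUDA tensors.
import torch
import torch.distributed as dist


def _flush_equivalence_demo() -> None:
    world_size = dist.get_world_size()
    data = torch.arange(world_size * 3, dtype=torch.float32, device="cuda").view(world_size, 3)
    out_base = torch.empty(3, device="cuda")
    out_list = torch.empty(3, device="cuda")
    dist._reduce_scatter_base(out_base, data.contiguous())   # flat "base" variant
    dist.reduce_scatter(out_list, list(data.unbind(0)))      # list variant
    assert torch.equal(out_base, out_list)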
def reduce_scatter_async(
    self,
    input_list: List[Tensor],
    group: ProcessGroup,
    callback_fn: Optional[Callable] = None,
) -> None:
    """
    Reduce-scatter a list of tensors asynchronously, so smaller reductions
    can be bucketed together. The given callback (``callback_fn``) will be
    called with the reduced result at some later time. Call ``flush()`` to
    force all queued ops and callbacks to be executed.

    Note that large inputs will be reduced immediately, and this function
    may also flush the relevant bucket to make room for ``input_list``.

    Args:
        input_list (List[Tensor]): list of tensors to reduce-scatter. List
            should contain ``group.size()`` tensors and each tensor should
            have identical shape, dtype and device.
        group (ProcessGroup): process group for reduction
        callback_fn (Callable, Optional): callback function to call after
            the reduction executes. Function will be called with a single
            argument corresponding to the reduced result.
    """
    world_size = group.size()

    assert (
        len(input_list) == world_size
    ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

    first_input = input_list[0]
    first_input_size = first_input.numel()

    bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
    if first_input_size > bucket_shard_size:
        # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
        # input is too big to fit in the bucket, reduce-scatter directly
        output = torch.zeros_like(input_list[0])
        if hasattr(dist, "_reduce_scatter_base"):
            input_flattened = torch.cat(input_list)
            dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
        else:
            # fallback
            dist.reduce_scatter(output, input_list, group=group)
        if callback_fn is not None:
            callback_fn(output)
        return

    bucket = self._get_bucket(first_input, group)
    if first_input_size > bucket.data.size(1) - bucket.offset:
        # not enough space remaining in bucket, flush it now
        bucket.flush()

    # copy data from input_list into bucket
    stacked_input = torch.stack(input_list).view(world_size, first_input_size)
    offset = bucket.offset
    bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
    bucket.offset += first_input_size

    # callback will be given the reduced result
    if callback_fn is not None:
        result_view = bucket.output_shard[offset : offset + first_input_size].view_as(first_input)
        bucket.callbacks.append(functools.partial(callback_fn, result_view))
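# Hedged usage sketch for ``reduce_scatter_async``: each rank passes a list of
# ``group.size()`` same-shaped tensors and the callback receives this rank's
# reduced shard once the bucket is flushed. The import path and constructor of
# ``ReduceScatterBucketer`` (the fairscale class this method belongs to) are
# assumptions and may differ across fairscale versions. Assumes an already
# initialized NCCL group.
import torch
import torch.distributed as dist
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer  # path assumed


def _bucketer_demo() -> None:
    group = dist.group.WORLD
    world_size = group.size()
    bucketer = ReduceScatterBucketer()
    # One small tensor per rank; small inputs are buffered in the bucket.
    grads = [torch.ones(8, device="cuda") * r for r in range(world_size)]

    def on_reduced(shard: torch.Tensor) -> None:
        # ``shard`` is a view into the bucket's output shard for this rank.
        print(f"rank {dist.get_rank()}: reduced shard mean = {shard.mean().item()}")

    bucketer.reduce_scatter_async(grads, group=group, callback_fn=on_reduced)
    bucketer.flush()  # force the queued reduction and the callback to run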
def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
    """
    At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
    full gradient for the local batch. The reduce-scatter op will replace
    ``param.grad`` with a single shard of the summed gradient across all
    GPUs. This shard will align with the current GPU rank. For example::

        before reduce_scatter:
            param.grad (GPU #0): [1, 2, 3, 4]
            param.grad (GPU #1): [5, 6, 7, 8]

        after reduce_scatter:
            param.grad (GPU #0): [6, 8]    # 1+5, 2+6
            param.grad (GPU #1): [10, 12]  # 3+7, 4+8

    The local GPU's ``optim.step`` is responsible for updating a single
    shard of params, also corresponding to the current GPU's rank. This
    alignment is created by :func:`_shard_parameters`, which ensures that
    the local optimizer only sees the relevant parameter shard.
    """
    # First hook callback will see PRE state. If we have multiple params,
    # then subsequent hook callbacks will see POST state.
    self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST])
    self.training_state = TrainingState_.BACKWARD_POST
    if param.grad is None:
        return
    if param.grad.requires_grad:
        raise RuntimeError(
            "FSDP only works with gradients that don't require gradients"
        )

    self._free_full_params([param])
    # Switch to local shard after backward.
    self._use_param_local_shard([param])

    # Wait for all work in the current stream to finish, then start the
    # reductions in post_backward stream.
    self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(self._streams["post_backward"]):
        orig_grad_data = param.grad.data

        if self.gradient_predivide_factor > 1:
            # Average grad by world_size for consistency with PyTorch DDP.
            param.grad.div_(self.gradient_predivide_factor)

        if param._is_sharded:  # type: ignore[attr-defined]
            grad_flatten = torch.flatten(param.grad)
            chunks = list(grad_flatten.chunk(self.world_size))
            num_pad = self.world_size * chunks[0].numel() - param.grad.numel()
            input_flattened = F.pad(grad_flatten, [0, num_pad])
            output = torch.zeros_like(chunks[0])
            dist._reduce_scatter_base(output, input_flattened, group=self.process_group)
            if self.gradient_postdivide_factor > 1:
                # Average grad by world_size for consistency with PyTorch DDP.
                output.div_(self.gradient_postdivide_factor)
            param.grad.data = output
        else:
            # Currently the only way for _is_sharded to be False is if
            # world_size == 1. This could be relaxed in the future, e.g,
            # no sharding like PyTorch DDP, in which case grads should be
            # all-reduced here.
            assert (
                self.world_size == 1
            ), "Currently the only way for _is_sharded to be False is world_size == 1"

        # After _post_backward_hook returns, orig_grad_data will eventually
        # go out of scope, at which point it could otherwise be freed for
        # further reuse by the main stream while the div/reduce_scatter/copy
        # are underway in the post_backward stream. See:
        # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
        orig_grad_data.record_stream(self._streams["post_backward"])
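# Hedged sketch of the padding arithmetic used above: ``chunk`` makes the
# leading chunks as large as possible, so padding the flattened gradient to
# ``world_size * chunks[0].numel()`` guarantees that every rank's shard has the
# same length before ``_reduce_scatter_base`` is called. Runs on CPU, no
# process group needed.
import torch
import torch.nn.functional as F


def _pad_for_reduce_scatter(grad: torch.Tensor, world_size: int) -> torch.Tensor:
    grad_flatten = torch.flatten(grad)
    chunks = list(grad_flatten.chunk(world_size))
    num_pad = world_size * chunks[0].numel() - grad.numel()
    return F.pad(grad_flatten, [0, num_pad])


# e.g. a 10-element gradient over 4 ranks -> chunks of 3, padded to 12 elements
assert _pad_for_reduce_scatter(torch.ones(10), 4).numel() == 12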