Example #1
    def backward(ctx, grad_output):
        if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
            world_size = dist.get_world_size(group=ctx.group)
            out_size = list(grad_output.size())
            if out_size[0] % world_size != 0:
                raise RuntimeError(
                    f'Tensor with dimensions: {out_size} does '
                    f'not have first dimension divisible by world_size: {world_size}'
                )
            # The gradient of an all-gather is a reduce-scatter: sum the incoming
            # gradient across ranks and keep only this rank's shard.
            out_size[0] = out_size[0] // world_size
            gx = torch.empty(out_size, device=grad_output.device, dtype=grad_output.dtype)
            dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group)
        else:
            raise RuntimeError("Backend not supported!")
        return (None, gx, None)
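
The backward above leans on the fact that the gradient of an all-gather is a reduce-scatter. The following single-process sketch checks that identity with plain autograd and a hand-written sum-then-shard; the tensor sizes are made up and no distributed call is issued:

import torch

world_size, shard_len = 2, 3
# per-"rank" inputs that an all-gather would concatenate
xs = [torch.randn(shard_len, requires_grad=True) for _ in range(world_size)]
# every rank's all-gather output is the concatenation of all inputs
outputs = [torch.cat(xs) for _ in range(world_size)]
# arbitrary upstream gradients arriving at each rank's output
grad_outputs = [torch.randn(shard_len * world_size) for _ in range(world_size)]

# autograd's gradient w.r.t. rank 0's input ...
(gx0,) = torch.autograd.grad(outputs, [xs[0]], grad_outputs)
# ... equals the reduce-scatter of the upstream gradients: sum over ranks, keep rank 0's shard
expected = torch.stack(grad_outputs).sum(dim=0)[:shard_len]
assert torch.allclose(gx0, expected)
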
Example #2
def reduce_scatter_hook(state: DefaultState, grad: torch.Tensor,
                        output: torch.Tensor):
    r"""
    This FSDP communication hook implements the ``reduce_scatter`` algorithm for
    sharded FSDP strategies, together with the necessary pre- and post-division of
    gradients.

    Args:
        state (DefaultState): State information, configures pre- and post-division factors.
        grad (torch.Tensor): An unsharded gradient for the local batch that needs to be
            communicated across ranks.
        output (torch.Tensor): Stores a single shard of the gradient after ``reduce_scatter``.
    """
    # Average grad by pre-division factor.
    if state.gradient_predivide_factor > 1:
        grad.div_(state.gradient_predivide_factor)
    dist._reduce_scatter_base(output, grad, group=state.process_group)
    # Average grad's shard by post-division factor.
    if state.gradient_postdivide_factor > 1:
        output.div_(state.gradient_postdivide_factor)
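
Dividing once before the reduction and once after is equivalent to a plain average over ranks whenever the two factors multiply to the world size, which is how FSDP's defaults are configured (asserted here from the defaults, not shown in the snippet); splitting the division keeps the summed intermediates smaller, which matters for fp16 gradients. A single-process sketch of the identity with made-up values:

import torch

world_size = 4
pre, post = 2.0, 2.0                         # assumed defaults, pre * post == world_size
grads = [torch.full((8,), float(r + 1)) for r in range(world_size)]  # toy per-rank grads

# pre-divide on every rank, sum across ranks (the "reduce"), then post-divide the result
reduced = torch.stack([g / pre for g in grads]).sum(dim=0) / post
# same value as averaging over ranks in one step, but each summand is pre times smaller
assert torch.allclose(reduced, torch.stack(grads).mean(dim=0))
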
Example #3
    def reduce_scatter_base(self, collectiveArgs, retFlag=False, pair=False):
        # Thin benchmark wrapper: forward the prepared tensors and group to
        # _reduce_scatter_base and keep the async work handle for a later wait.
        retObj = dist._reduce_scatter_base(
            output=collectiveArgs.opTensor,
            input=collectiveArgs.ipTensor,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )  # synchronicity is maintained in runColl

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(retObj)

        if retFlag:
            return retObj
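
The wrapper stashes the returned work object when ``async_op`` is set; outside the benchmark harness the same pattern looks roughly like the sketch below. It assumes an already-initialized NCCL process group and CUDA tensors, and the function name is made up:

import torch
import torch.distributed as dist

def queue_reduce_scatter(output: torch.Tensor, flat_input: torch.Tensor, group=None):
    # with async_op=True the call returns a Work handle instead of blocking
    work = dist._reduce_scatter_base(output, flat_input, group=group, async_op=True)
    # ... other computation can overlap with the collective here ...
    work.wait()  # block until the reduce-scatter has completed
    return output
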
Example #4
    def flush(self) -> None:
        """Flush content of the bucket."""
        if self.offset == 0:
            assert len(self.callbacks) == 0
            return
        # reduce-scatter bucket
        if hasattr(dist, "_reduce_scatter_base"):
            dist._reduce_scatter_base(  # type: ignore
                self.output_shard[:self.offset],
                self.data[:, :self.offset].contiguous(),
                group=self.group)
        else:
            dist.reduce_scatter(self.output_shard[:self.offset],
                                list(self.data[:, :self.offset].unbind(0)),
                                group=self.group)
        # execute post-reduction callbacks
        for callback_fn in self.callbacks:
            callback_fn()
        # reuse input bucket but allocate a fresh output shard
        self.data[:, :self.offset].zero_()
        self.offset = 0
        self.callbacks.clear()
        self.output_shard = torch.zeros_like(self.data[0])
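
A single flat reduce-scatter over the packed ``data`` buffer gives each rank the same shard it would get from reducing every small tensor separately, simply because summation is linear. A single-process sketch with made-up shapes (no process group is created):

import torch

world_size, rank = 2, 0
# per-rank contributions for two small tensors, stacked the way the bucket stores them:
# row i of each block is the data destined for rank i
a = [torch.randn(world_size, 3) for _ in range(world_size)]
b = [torch.randn(world_size, 5) for _ in range(world_size)]

# reduce-scatter of each block on its own (simulated: sum over ranks, keep own row)
ref_a = sum(a)[rank]
ref_b = sum(b)[rank]

# bucketed version: pack both blocks side by side and reduce once
packed = [torch.cat([a[p], b[p]], dim=1) for p in range(world_size)]
shard = sum(packed)[rank]
assert torch.allclose(shard, torch.cat([ref_a, ref_b]))
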
Example #5
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()

        assert (
            len(input_list) == world_size
        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(),
                                                 world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base"):
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened,
                                          group=group)  # type: ignore
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.data.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()

        # copy data from input_list into bucket
        stacked_input = torch.stack(input_list).view(world_size,
                                                     first_input_size)
        offset = bucket.offset
        bucket.data[:, offset:offset + first_input_size].copy_(stacked_input)
        bucket.offset += first_input_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = bucket.output_shard[offset:offset +
                                              first_input_size].view_as(
                                                  first_input)
            bucket.callbacks.append(functools.partial(callback_fn,
                                                      result_view))
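
A hedged usage sketch of the method above: it assumes the surrounding class is a reduce-scatter bucketer instantiated as ``bucketer``, that ``group`` is an initialized process group, and that ``param.grad`` has a length divisible by the world size; none of these names appear in the snippet itself.

def write_back(reduced_shard, param):
    # runs later, once the bucket holding this gradient has been flushed
    param.grad = reduced_shard.clone()

# one equal-sized chunk per rank, as the docstring requires
chunks = list(param.grad.detach().chunk(group.size()))
bucketer.reduce_scatter_async(
    chunks,
    group=group,
    callback_fn=lambda shard: write_back(shard, param),
)
# ... later, e.g. once backward has finished ...
bucketer.flush()  # force queued reductions and their callbacks to run, per the docstring
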
Example #6
    def _post_backward_hook(self, param: Parameter, *unused: Any) -> None:
        """
        At the start of :func:`_post_backward_hook`, ``param.grad`` contains the
        full gradient for the local batch. The reduce-scatter op will replace
        ``param.grad`` with a single shard of the summed gradient across all
        GPUs. This shard will align with the current GPU rank. For example::
            before reduce_scatter:
                param.grad (GPU #0): [1, 2, 3, 4]
                param.grad (GPU #1): [5, 6, 7, 8]
            after reduce_scatter:
                param.grad (GPU #0): [6, 8]    # 1+5, 2+6
                param.grad (GPU #1): [10, 12]  # 3+7, 4+8
        The local GPU's ``optim.step`` is responsible for updating a single
        shard of params, also corresponding to the current GPU's rank. This
        alignment is created by :func:`_shard_parameters`, which ensures that
        the local optimizer only sees the relevant parameter shard.
        """
        # First hook callback will see PRE state. If we have multiple params,
        # then subsequent hook callbacks will see POST state.
        self._assert_state([TrainingState_.BACKWARD_PRE, TrainingState_.BACKWARD_POST])
        self.training_state = TrainingState_.BACKWARD_POST
        if param.grad is None:
            return

        if param.grad.requires_grad:
            raise RuntimeError(
                "FSDP only works with gradients that don't require gradients"
            )

        self._free_full_params([param])
        # Switch to local shard after backward.
        self._use_param_local_shard([param])

        # Wait for all work in the current stream to finish, then start the
        # reductions in post_backward stream.
        self._streams["post_backward"].wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self._streams["post_backward"]):
            orig_grad_data = param.grad.data

            if self.gradient_predivide_factor > 1:
                # Pre-divide grad; together with the post-division below this
                # averages by world_size, for consistency with PyTorch DDP.
                param.grad.div_(self.gradient_predivide_factor)

            if param._is_sharded:  # type: ignore[attr-defined]
                grad_flatten = torch.flatten(param.grad)
                chunks = list(grad_flatten.chunk(self.world_size))
                num_pad = self.world_size * chunks[0].numel() - param.grad.numel()
                input_flattened = F.pad(grad_flatten, [0, num_pad])
                output = torch.zeros_like(chunks[0])
                dist._reduce_scatter_base(
                    output, input_flattened, group=self.process_group
                )
                if self.gradient_postdivide_factor > 1:
                    # Post-divide the reduced shard to complete the world_size averaging.
                    output.div_(self.gradient_postdivide_factor)
                param.grad.data = output
            else:
                # Currently the only way for _is_sharded to be False is if
                # world_size == 1. This could be relaxed in the future, e.g,
                # no sharding like PyTorch DDP, in which case grads should be
                # all-reduced here.
                assert (
                    self.world_size == 1
                ), "Currently the only way for _is_sharded to be False is world_size == 1"

            # After _post_backward_hook returns, orig_grad_data will eventually
            # go out of scope, at which point it could otherwise be freed for
            # further reuse by the main stream while the div/reduce_scatter/copy
            # are underway in the post_backward stream. See:
            # github.com/NVIDIA/apex/blob/master/apex/parallel/distributed.py
            orig_grad_data.record_stream(self._streams["post_backward"])
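
The worked example in the docstring can be checked without GPUs; the sketch below reproduces, with plain tensors, what the reduce-scatter in the sharded branch computes (element-wise sum across ranks, then one equal chunk per rank):

import torch

world_size = 2
grads = [torch.tensor([1., 2., 3., 4.]),   # param.grad on GPU #0
         torch.tensor([5., 6., 7., 8.])]   # param.grad on GPU #1

summed = torch.stack(grads).sum(dim=0)     # tensor([ 6.,  8., 10., 12.])
shards = list(summed.chunk(world_size))    # GPU #0 keeps [6., 8.], GPU #1 keeps [10., 12.]
print(shards[0], shards[1])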