Example #1
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
            _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
        )
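# Not part of the original excerpt: a plausible sketch of the `_maybe_cuda`
# helper used above, inferred from its call sites (a module plus a boolean
# flag). The actual test utility may differ.
import torch.nn as nn

def _maybe_cuda(module: nn.Module, move_to_cuda: bool) -> nn.Module:
    # Move the freshly built layer to the current CUDA device only when
    # CUDAInitMode.CUDA_BEFORE was requested; otherwise leave it on CPU.
    return module.cuda() if move_to_cuda else module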
Example #2
def allreduce_hook(process_group: dist.ProcessGroup,
                   bucket: dist._GradBucket) -> torch.futures.Future:
    """
    This DDP communication hook simply calls ``allreduce`` on the ``GradBucket``
    tensors. Once the gradient tensors are aggregated across all workers, its
    ``then`` callback takes the mean and returns the result. If a user registers
    this hook, DDP results are expected to be the same as if no hook were
    registered. Hence, this hook does not change the behavior of DDP, and users
    can use it as a reference or modify it to log useful information or for
    other purposes without affecting DDP behavior.

    Example::
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (process_group.size()
                  if process_group is not None else dist.get_world_size())

    tensor = bucket.get_tensors()[0]
    fut = dist.all_reduce(tensor, group=group_to_use,
                          async_op=True).get_future()

    def then_callback(fut):
        return [fut.value()[0].div_(world_size)]

    return fut.then(then_callback)
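# Not part of the original excerpt: a minimal usage sketch for the hook above.
# It assumes torch.distributed is already initialized with a CUDA backend
# (e.g. NCCL) and that `local_rank` is this process's GPU index.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def build_ddp_with_allreduce_hook(local_rank: int) -> DDP:
    model = nn.Linear(8, 4).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank])
    # Passing None as the state makes the hook fall back to the default WORLD
    # group, per the `process_group is not None` checks above.
    ddp_model.register_comm_hook(None, allreduce_hook)
    return ddp_model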
Example #3
def fp16_compress_hook(process_group: dist.ProcessGroup,
                       bucket: dist._GradBucket) -> torch.futures.Future:
    """
    This DDP communication hook implements a simple gradient compression
    approach that casts ``GradBucket`` tensors, assumed to be of type
    ``torch.float32``, to half-precision floating point (``torch.float16``),
    and then allreduces those ``float16`` gradient tensors. Once the compressed
    gradient tensors are allreduced, its ``then`` callback, ``decompress``,
    converts the aggregated result back to ``float32`` and takes the mean.

    Example::
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (process_group.size()
                  if process_group is not None else dist.get_world_size())

    compressed_tensor = bucket.get_tensors()[0].to(torch.float16)

    fut = dist.all_reduce(compressed_tensor, group=group_to_use,
                          async_op=True).get_future()

    def decompress(fut):
        decompressed_tensor = bucket.get_tensors()[0]
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        decompressed_tensor.copy_(fut.value()[0].div_(world_size))
        return [decompressed_tensor]

    return fut.then(decompress)
Example #4
def _allgather_then_aggregate_hook(
        process_group: dist.ProcessGroup,
        bucket: dist._GradBucket) -> torch.futures.Future:
    """
    Similar to ``allreduce_hook``, this hook first gathers the ``GradBucket``
    tensors, and its ``then`` callback aggregates the gathered gradient tensors
    and takes the mean. Instead of ``allreduce``, this hook uses ``allgather``.
    Note that with W workers, both the computation and communication time scale
    as O(W) for allgather, compared to O(logW) for allreduce. Therefore, this
    hook is expected to be much slower than ``allreduce_hook``, although both
    essentially do the same thing with the gradients.

    .. warning ::
        This is for tests and experiments. Users are advised to use the faster
        alternative ``allreduce_hook``, which uses the ``allreduce`` protocol
        instead of the ``allgather`` protocol.

    Example::
        >>> ddp_model.register_comm_hook(process_group, _allgather_then_aggregate_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank(
    ) if process_group is not None else dist.get_rank()
    world_size = (process_group.size()
                  if process_group is not None else dist.get_world_size())

    tensor = bucket.get_tensors()[0]
    fut = dist.all_gather(
        _get_allgather_out_list(tensor, world_size),
        tensor,
        group=group_to_use,
        async_op=True,
    ).get_future()

    def aggregate(fut):
        all_ranks_tensor = fut.value()[0]
        tensor = bucket.get_tensors()[0]
        for r, gathered_tensor in enumerate(all_ranks_tensor):
            if r != rank:
                tensor += gathered_tensor

        return [tensor.div_(world_size)]

    return fut.then(aggregate)
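# Not part of the original excerpt: a plausible sketch of the
# `_get_allgather_out_list` helper referenced above, inferred from how
# `dist.all_gather` consumes its result (one pre-allocated tensor per rank).
# The real helper may differ.
import torch

def _get_allgather_out_list(all_gather_in_list, world_size):
    # One output slot per rank, matching the input's shape, dtype, and device.
    return [torch.zeros_like(all_gather_in_list) for _ in range(world_size)]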
Example #5
    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            data = tensor.new_zeros((world_size, shard_size))
            self.buckets[key] = Bucket(data, group)
        return self.buckets[key]
Example #6
    def __init__(
        self,
        group: dist.ProcessGroup,
        cuda_init_mode: CUDAInitMode,
        add_bn: bool,
        deterministic: bool,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        if deterministic:
            torch.manual_seed(0)
        d_vocab = 23
        d_model = 16

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer("vocab_bias",
                             self.embed_tokens.weight.new_ones((d_model, )))
        self.register_buffer(
            "long_buffer",
            torch.zeros_like(self.vocab_bias, dtype=torch.long),
        )  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(
            self.bs) if add_bn else torch.nn.Identity()
        if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
            self = self.cuda()
        if deterministic:
            self.eval()
Example #7
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # we create a process group of size 1 for the expert params
            expert_group = torch.distributed.new_group(
                [group.rank()])  # world size 1 means no shard
            expert = FSDP(expert, expert_group,
                          **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group,
                          **fsdp_kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared, expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda))
Example #8
    def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
        # TODO (Min): the `group` used here in the key is the object hash, not the content
        #     hash. That means if FSDP instances are initialized with different process groups,
        #     even when the group members are in fact the same, we end up creating different
        #     buckets here.
        key = (tensor.dtype, tensor.device, group)
        if key not in self.buckets:
            # buckets are divided into world_size pieces, bucket.data shaped (world_size, shard_size)
            world_size = group.size()
            shard_size = self._get_shard_size(tensor.element_size(), world_size)
            data = tensor.new_zeros((world_size, shard_size))
            self.buckets[key] = Bucket(data, group)
        self.buckets[key].setup()
        return self.buckets[key]
Example #9
def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
    """Do a quick test in case user called FSDP without calling torch.cuda.set_device()
    correctly. This can easily happen in cpu_offload case where the model resides on
    the CPU.
    """
    if not hasattr(process_group, "allgather"):
        # Likely a dummy pg for unit test, skip checking.
        return

    world_size = process_group.size()
    if "cuda" in str(device):
        input_tensor = torch.ones(1).to(device)
        output = list(torch.zeros(world_size).to(device).chunk(world_size))
        dist.all_gather(output, input_tensor, group=process_group)
        assert torch.cat(output).sum() == float(world_size), (
            f"found {torch.cat(output).sum()} devices in process group but "
            f"world_size={world_size}. Check torch.cuda.set_device is called properly"
        )
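# Not part of the original excerpt: a short usage sketch for the check above.
# The surrounding setup (local_rank, device pinning) is an assumption about
# the caller's environment.
import torch
import torch.distributed as dist

def _setup_and_validate(local_rank: int) -> None:
    # Pin this process to its GPU first, then sanity-check the default group.
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    validate_process_group(device, dist.group.WORLD)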
Example #10
def quantization_pertensor_hook(
        process_group: dist.ProcessGroup,
        bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_tensor`` logic to DDP using the ``allgather``
    protocol. Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to the quantization. After all workers have that
    information, the first ``then`` callback, ``quantize_and_allgather``,
    quantizes the worker's own gradient tensor and uses ``allgather`` to
    communicate it across all workers. The final ``then`` callback,
    ``dequantize_and_aggregate``, dequantizes and aggregates each quantized
    gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses ``allgather`` protocol which is considerably slower than
        ``allreduce`` protocol. It works only with flattened grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank(
    ) if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    myObserver = torch.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(all_ranks_s_and_z,
                          s_and_z,
                          group=group_to_use,
                          async_op=True).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their own ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1])
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0],
            device=tensor.device,
            dtype=torch.float32)
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0],
                all_ranks_s_and_z[r][1])

        return aggregated_dequantized_tensor / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
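# Not part of the original excerpt: a simplified sketch of what the
# `_quantize_per_tensor_cuda` / `_dequantize_per_tensor_cuda` helpers used
# above could look like (uniform affine quantization to uint8). The real
# helpers may differ, e.g. in clamping or rounding details.
import torch

def _quantize_per_tensor_cuda(x, scale, zero_point):
    # q = round(x / scale) + zero_point, clamped to the uint8 range.
    y = torch.round(x / scale) + zero_point
    return torch.clamp(y, 0, 255).to(torch.uint8)

def _dequantize_per_tensor_cuda(y, scale, zero_point):
    # Inverse mapping back to float32: x = scale * (q - zero_point).
    return scale * (y.to(torch.float32) - zero_point)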
Example #11
def quantization_perchannel_hook(
        process_group: dist.ProcessGroup,
        bucket: dist.GradBucket,
        bucket_size=512) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_channel`` logic to DDP using the ``allgather``
    protocol. Compared to per-tensor quantization, the main motivation of the
    per-channel variant is that for considerably large tensors (such as one with
    6 million elements), quantizing per bucket of 512 (or 128) elements may
    significantly increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements. Then, workers allgather the scales and zero points
    of their own ``GradBucket`` prior to the quantization. After all workers have
    that information, the first ``then`` callback, ``quantize_and_allgather``,
    quantizes the worker's own gradient tensor and uses ``allgather`` to
    communicate it across all workers. The final ``then`` callback,
    ``dequantize_and_aggregate``, dequantizes, flattens, and aggregates each
    quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses ``allgather`` protocol which is considerably slower than
        ``allreduce`` protocol. It works only with flattened grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank(
    ) if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    tensor_in_channels = (nn.functional.pad(
        input=tensor,
        pad=(0, bucket_size - len(tensor) % bucket_size),
        mode="constant",
        value=0,
    ).view(-1, bucket_size).cuda(tensor.device))

    myPerChannelObserver = torch.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device)
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(all_ranks_s_and_z,
                          s_and_z,
                          group=group_to_use,
                          async_op=True).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0],
            device=tensor.device,
            dtype=torch.float32)
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0],
                all_ranks_s_and_z[r][1])

        return (torch.flatten(aggregated_dequantized_tensor).cuda(
            tensor.device)[:tensor.size()[0]] / world_size)

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
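# Not part of the original excerpt: an illustrative check of the padding
# arithmetic above. A flattened gradient of 6,000,000 elements with
# bucket_size=512 receives 128 zeros of padding, giving 11,719 channels.
n, bucket_size = 6_000_000, 512
pad = bucket_size - n % bucket_size      # 512 - 384 = 128 zeros appended
padded = n + pad                         # 6,000,128 elements after padding
assert padded % bucket_size == 0 and padded // bucket_size == 11_719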
Example #12
def get_global_ranks_from_group(group: ProcessGroup) -> List[int]:
    return [_get_global_rank(group, r) for r in range(group.size())]
Example #13
    def reduce_scatter_async(
        self,
        input_list: List[Tensor],
        group: ProcessGroup,
        callback_fn: Optional[Callable] = None,
    ) -> None:
        """
        Reduce-scatter a list of tensors asynchronously, so smaller reductions
        can be bucketed together. The given callback (``callback_fn``) will be
        called with the reduced result at some later time. Call ``flush()`` to
        force all queued ops and callbacks to be executed.

        Note that large inputs will be reduced immediately, and this function
        may also flush the relevant bucket to make room for ``input_list``.

        Args:
            input_list (List[Tensor]): list of tensors to reduce-scatter. List
                should contain ``group.size()`` tensors and each tensor should
                have identical shape, dtype and device.
            group (ProcessGroup): process group for reduction
            callback_fn (Callable, Optional): callback function to call after
                the reduction executes. Function will be called with a single
                argument corresponding to the reduced result.
        """
        world_size = group.size()

        assert (
            len(input_list) == world_size
        ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

        first_input = input_list[0]
        first_input_size = first_input.numel()

        bucket_shard_size = self._get_shard_size(first_input.element_size(),
                                                 world_size)
        if first_input_size > bucket_shard_size:
            # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
            # input is too big to fit in the bucket, reduce-scatter directly
            output = torch.zeros_like(input_list[0])
            if hasattr(dist, "_reduce_scatter_base"):
                input_flattened = torch.cat(input_list)
                dist._reduce_scatter_base(output, input_flattened,
                                          group=group)  # type: ignore
            else:
                # fallback
                dist.reduce_scatter(output, input_list, group=group)
            if callback_fn is not None:
                callback_fn(output)
            return

        bucket = self._get_bucket(first_input, group)
        if first_input_size > bucket.data.size(1) - bucket.offset:
            # not enough space remaining in bucket, flush it now
            bucket.flush()

        # copy data from input_list into bucket
        stacked_input = torch.stack(input_list).view(world_size,
                                                     first_input_size)
        offset = bucket.offset
        bucket.data[:, offset:offset + first_input_size].copy_(stacked_input)
        bucket.offset += first_input_size

        # callback will be given the reduced result
        if callback_fn is not None:
            result_view = bucket.output_shard[offset:offset +
                                              first_input_size].view_as(
                                                  first_input)
            bucket.callbacks.append(functools.partial(callback_fn,
                                                      result_view))
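# Not part of the original excerpt: a hedged sketch of the `Bucket` container
# that `_get_bucket` and `reduce_scatter_async` assume. Field and method names
# are inferred from how they are used above (`data`, `offset`, `output_shard`,
# `callbacks`, `flush()`, `setup()`); the real bucketing implementation is
# more involved.
from dataclasses import dataclass, field
from typing import Callable, List, Optional

import torch
import torch.distributed as dist
from torch import Tensor

@dataclass
class Bucket:
    data: Tensor                                   # shaped (world_size, shard_size)
    group: dist.ProcessGroup
    offset: int = 0                                # elements queued per shard row
    callbacks: List[Callable] = field(default_factory=list)
    output_shard: Optional[Tensor] = None          # this rank's reduced shard

    def setup(self) -> None:
        # Lazily allocate the shard that reduce-scatter writes into.
        if self.output_shard is None:
            self.output_shard = torch.zeros_like(self.data[0])

    def flush(self) -> None:
        # Reduce-scatter the queued data, run pending callbacks, then reset.
        if self.offset == 0:
            return
        inputs = [t.contiguous() for t in self.data[:, :self.offset].unbind(0)]
        dist.reduce_scatter(self.output_shard[:self.offset], inputs, group=self.group)
        for callback_fn in self.callbacks:
            callback_fn()
        self.data.zero_()
        self.offset = 0
        self.callbacks.clear()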
Example #14
def powerSGD_hook(
    process_group: dist.ProcessGroup,
    bucket: dist._GradBucket,
    matrix_approximation_rank: int = 1,
) -> torch.futures.Future:
    """
    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:
    1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
    2) Decomposes M into two low-rank tensors P and Q, such that M = PQ^T,
       where Q is initialized from a standard normal distribution and orthogonalized;
    3) Allreduces P;
    4) Orthogonalizes P;
    5) Computes Q, which is approximately equal to M^TP;
    6) Allreduces Q;
    7) Computes M, which is approximately equal to PQ^T;
    8) Truncates the input tensor to the original length.

    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
    one left multiplication and one right multiplication.
    For warm start, can take one such step at a time, and alternate between them.

    Arguments:
        process_group (dist.ProcessGroup): Process group to communicate.
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor that batches multiple per-variable tensors.
            Note that since DDP comm hook only supports single process single device mode at this time,
            only exactly one tensor is stored in this bucket.
        matrix_approximation_rank (int): The low rank for matrix approximation.
            Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> state = PowerSGDState(process_group, 1)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (process_group.size()
                  if process_group is not None else dist.get_world_size())

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    total_length = input_tensor.shape[0]

    # View the input tensor as a 2D square-shape tensor, and pad 0s if necessary.
    square_side_length = math.ceil(math.sqrt(total_length))
    padded_total_length = square_side_length**2
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)
    matrix = input_tensor.view(square_side_length, square_side_length)

    def create_low_rank_tensor(fill_random_values):
        "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
        if fill_random_values:
            with torch.random.fork_rng(devices=[device]):
                # The seed makes sure that the initial random values are the same across all the DDP replicas.
                # Such seed should differ at every step.
                # Currently use the length of input tensor as the seed, which should be mostly different.
                # TODO(wayi@): Should read the random seed from the state of this hook provided by the constructor.
                torch.manual_seed(total_length)
                return torch.randn(square_side_length,
                                   matrix_approximation_rank,
                                   device=device)
        else:
            return torch.empty(square_side_length,
                               matrix_approximation_rank,
                               device=device)

    p = create_low_rank_tensor(fill_random_values=False)
    q = create_low_rank_tensor(fill_random_values=True)
    _orthogonalize(q, 0)

    torch.matmul(matrix, q, out=p)
    allreduce_p_fut = dist.all_reduce(p, group=group_to_use,
                                      async_op=True).get_future()

    def compute_q(fut):
        p = fut.value()[0]
        _orthogonalize(p, 0)

        torch.matmul(matrix.t(), p, out=q)

        return [
            dist.all_reduce(q, group=group_to_use,
                            async_op=True).get_future().value()[0]
        ]

    def decompress(fut):
        q = fut.value()[0].div_(world_size)
        torch.matmul(p, q.t(), out=matrix)

        ret = input_tensor.resize_(total_length)
        return [ret]

    return allreduce_p_fut.then(compute_q).then(decompress)
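# Not part of the original excerpt: a hedged sketch of a Gram-Schmidt style
# `_orthogonalize` helper consistent with how it is called above (in place,
# column-wise, with an optional epsilon as the second argument). The real
# helper may guard differently against near-zero norms.
import torch

def _orthogonalize(matrix: torch.Tensor, epsilon: float = 1e-8) -> None:
    # In-place orthogonalization of the columns of `matrix`
    # (shape: square_side_length x matrix_approximation_rank).
    num_cols = matrix.shape[1]
    for i in range(num_cols):
        col = matrix[:, i:i + 1]
        # Normalize the current column; epsilon=0 assumes a nonzero norm.
        col.div_(torch.norm(col) + epsilon)
        if i + 1 < num_cols:
            rest = matrix[:, i + 1:]
            # Remove the component along `col` from the remaining columns.
            rest.sub_(col @ (col.t() @ rest))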