Example #1
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _maybe_wrap(layer):
            if wrap_fsdp:
                return FSDP(layer, group, **fsdp_kwargs)
            return layer

        if deterministic:
            torch.manual_seed(0)
        # Nested structure: the inner ``nn.Sequential`` (and one of its layers) may
        # themselves be wrapped in FSDP, exercising nested wrapping.
        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
            _maybe_wrap(
                nn.Sequential(
                    _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                    _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
                ),
            ),
            _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
            _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
        )
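For orientation, a minimal usage sketch follows. The enclosing class is not named in the snippet, so ``NestedWrappedModule`` is assumed here, as are an already-initialized ``torch.distributed`` process group and the FSDP test helpers (``CUDAInitMode``, ``_maybe_cuda``); this is a hypothetical sketch, not part of the snippet above.

# Hypothetical usage sketch -- class name and setup are assumptions.
group = dist.new_group()  # group containing all ranks
model = NestedWrappedModule(
    group=group,
    wrap_fsdp=True,
    cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
    deterministic=True,
)
out = model.module(torch.randn(2, 8, device="cuda"))  # input feature dim is 8
out.sum().backward()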
Example #2
    def __init__(
        self,
        group: dist.ProcessGroup,
        wrap_fsdp: bool,
        cuda_init_mode: CUDAInitMode,
        delay_before_free_ms: int,
        deterministic: bool,
        **fsdp_kwargs,
    ):
        super().__init__(
            group=group,
            wrap_fsdp=wrap_fsdp,
            cuda_init_mode=cuda_init_mode,
            deterministic=deterministic,
        )
        self.group = group
        self.delay_before_free_ms = delay_before_free_ms
        self.wrap_fsdp = wrap_fsdp
        self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
        if deterministic:
            # Give each rank different expert parameters
            torch.manual_seed(42 + self.rank)
        d_expert = 23
        d_shared = 12
        d_input = 8
        expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)

        self.num_expert_params = sum([p.numel() for p in expert.parameters()])
        for p in expert.parameters():
            p.expert = True  # type: ignore[attr-defined]

        if deterministic:
            # Keep all other parameters the same across ranks
            torch.manual_seed(0)

        shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

        if wrap_fsdp:
            # Create a process group of size 1 for the expert parameters; with a
            # world size of 1, FSDP does not shard the expert.
            expert_group = torch.distributed.new_group([group.rank()])
            expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
            shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]

        self.module = nn.Sequential(
            _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
            shared,
            expert,
            _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
        )
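The ``expert`` attribute set on the expert parameters above is only a marker. As a hedged illustration (for the unwrapped ``wrap_fsdp=False`` case, where the original parameters are still visible, and assuming ``model`` is an instance of the class above), such a marker could be consumed to give expert and shared parameters separate optimizer groups:

# Hypothetical consumer of the ``p.expert`` marker -- not part of the snippet above.
expert_params = [p for p in model.parameters() if getattr(p, "expert", False)]
shared_params = [p for p in model.parameters() if not getattr(p, "expert", False)]
optim = torch.optim.SGD(
    [
        {"params": shared_params},
        {"params": expert_params, "lr": 1e-3},  # e.g. a smaller expert learning rate
    ],
    lr=1e-2,
)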
Example #3
def _allgather_then_aggregate_hook(
        process_group: dist.ProcessGroup,
        bucket: dist._GradBucket) -> torch.futures.Future:
    """
    Similar to ``allreduce_hook``, this hook first gathers ``GradBucket`` tensors
    and its ``then`` callback aggregates the gathered gradient tensors and takes
    mean. Instead of ``allreduce`` this hook uses ``allgather``. Note that with
    W workers, both the computation and communication time scale as O(W) for
    allgather compared to O(logW) for allreduce. Therefore, this hook is expected
    to be much slower than ``allreduce_hook`` although both essentially do the
    same thing with the gradients.

    .. warning ::
        This is for test and experiments. User is suggested to use a faster
        alternative called ``allreduce_hook``  that uses ``allreduce`` protocol
        instead of ``allgather`` protocol.

    Example::
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]
    fut = dist.all_gather(
        _get_allgather_out_list(tensor, world_size),
        tensor,
        group=group_to_use,
        async_op=True,
    ).get_future()

    def aggregate(fut):
        all_ranks_tensor = fut.value()[0]
        tensor = bucket.get_tensors()[0]
        # The local gradient is already in ``tensor``; add only the contributions
        # gathered from the other ranks, then average.
        for r, gathered_tensor in enumerate(all_ranks_tensor):
            if r != rank:
                tensor += gathered_tensor

        return [tensor.div_(world_size)]

    return fut.then(aggregate)
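A hedged end-to-end registration sketch (assumptions: ``torch.distributed`` is initialized with a NCCL backend and there is one GPU per process; passing ``state=None`` makes the hook fall back to the default world group, as its body shows):

# Minimal, hypothetical registration sketch -- setup is an assumption.
model = nn.Linear(16, 16).cuda()
ddp_model = nn.parallel.DistributedDataParallel(
    model, device_ids=[torch.cuda.current_device()]
)
ddp_model.register_comm_hook(state=None, hook=_allgather_then_aggregate_hook)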
Example #4
    def __init__(
        self,
        group: dist.ProcessGroup,
        cuda_init_mode: CUDAInitMode,
        add_bn: bool,
        deterministic: bool,
    ):
        super().__init__()
        self.rank = group.rank()
        self.world_size = group.size()
        if deterministic:
            torch.manual_seed(0)
        d_vocab = 23
        d_model = 16

        self.embed_tokens = nn.Embedding(d_vocab, d_model)
        self.transformer = nn.Transformer(
            d_model=d_model,
            num_encoder_layers=2,
            num_decoder_layers=2,
            dim_feedforward=8,
            dropout=0.1,
        )
        self.output_proj = nn.Linear(d_model, d_vocab)

        # share the embedding and output projection weights
        self.output_proj.weight = self.embed_tokens.weight
        self.register_buffer("vocab_bias",
                             self.embed_tokens.weight.new_ones((d_model, )))
        self.register_buffer(
            "long_buffer",
            torch.zeros_like(self.vocab_bias, dtype=torch.long),
        )  # type: ignore[arg-type]

        self.bs = 2
        self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
        if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
            self.cuda()  # ``nn.Module.cuda()`` moves parameters and buffers in place
        if deterministic:
            self.eval()
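A hedged forward-pass sketch for an instance of this class (called ``model`` here; the class's actual ``forward`` helper is not shown above), illustrating the tied embedding/output-projection path through ``nn.Transformer``:

# Hypothetical forward pass -- shapes and the ``model`` name are assumptions.
seq_len, batch = 4, 2
device = model.embed_tokens.weight.device
src_ids = torch.randint(0, 23, (seq_len, batch), device=device)  # d_vocab = 23
tgt_ids = torch.randint(0, 23, (seq_len, batch), device=device)
src = model.embed_tokens(src_ids)   # (seq_len, batch, d_model)
tgt = model.embed_tokens(tgt_ids)
logits = model.output_proj(model.transformer(src, tgt))  # (seq_len, batch, d_vocab)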
def quantization_pertensor_hook(
        process_group: dist.ProcessGroup,
        bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_tensor`` logic to DDP using the ``allgather``
    protocol. Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to quantization. Once all workers have that information,
    the first ``then`` callback, ``quantize_and_allgather``, quantizes the worker's
    own gradient tensor and uses ``allgather`` to communicate it across all workers.
    The final ``then`` callback, ``dequantize_and_aggregate``, dequantizes and
    aggregates each quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and it uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    # ``bucket.buffer()`` returns the bucket's flattened 1D gradient tensor.
    tensor = bucket.buffer()

    myObserver = torch.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(all_ranks_s_and_z,
                          s_and_z,
                          group=group_to_use,
                          async_op=True).get_future()

    def quantize_and_allgather(fut):
        # Gathered scales and zero points from all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their own ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1])
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0],
            device=tensor.device,
            dtype=torch.float32)
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0],
                all_ranks_s_and_z[r][1])

        return aggregated_dequantized_tensor / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
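To make the per-tensor scheme concrete, here is a small self-contained round trip through the observer and quantization calls the hook relies on (a sketch independent of DDP; values and shapes are made up for illustration):

import torch

t = torch.tensor([-1.0, 0.0, 0.5, 2.0])
observer = torch.quantization.MinMaxObserver()  # tracks the min/max it has seen
observer(t)
scale, zero_point = observer.calculate_qparams()
q = torch.quantize_per_tensor(t, float(scale), int(zero_point), torch.quint8)
print(q.dequantize())  # recovers t up to quantization error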
def quantization_perchannel_hook(
        process_group: dist.ProcessGroup,
        bucket: dist.GradBucket,
        bucket_size=512) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_channel`` logic to DDP using the ``allgather``
    protocol. Compared to per-tensor quantization, the main motivation for the
    per-channel variant is very large tensors: for a tensor with, say, six million
    elements, quantizing in chunks ("channels") of 512 (or 128) elements can
    significantly increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements. Then, workers allgather the scales and zero points of
    their own ``GradBucket`` prior to quantization. Once all workers have that
    information, the first ``then`` callback, ``quantize_and_allgather``, quantizes
    the worker's own gradient tensor and uses ``allgather`` to communicate it
    across all workers. The final ``then`` callback, ``dequantize_and_aggregate``,
    dequantizes, flattens, and aggregates each quantized gradient tensor locally
    and returns the mean.

    .. warning ::
        This is experimental, and it uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    # Pad the flattened gradient so its length is a multiple of ``bucket_size``,
    # then view it as one channel (row) of ``bucket_size`` elements each.
    tensor_in_channels = (nn.functional.pad(
        input=tensor,
        pad=(0, bucket_size - len(tensor) % bucket_size),
        mode="constant",
        value=0,
    ).view(-1, bucket_size).cuda(tensor.device))

    myPerChannelObserver = torch.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device)
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(all_ranks_s_and_z,
                          s_and_z,
                          group=group_to_use,
                          async_op=True).get_future()

    def quantize_and_allgather(fut):
        # Gathered scales and zero points from all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0],
            device=tensor.device,
            dtype=torch.float32)
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0],
                all_ranks_s_and_z[r][1])

        # Drop the zero padding added before quantization and average over workers.
        flattened = torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)
        return flattened[:tensor.size()[0]] / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
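And a matching self-contained illustration of the per-channel observer used above (a sketch independent of DDP; the shape mirrors ``bucket_size = 512`` chunks):

import torch

x = torch.randn(4, 512)  # 4 "channels" of bucket_size = 512 elements each
observer = torch.quantization.PerChannelMinMaxObserver(ch_axis=0)
observer(x)
scales, zero_points = observer.calculate_qparams()
q = torch.quantize_per_channel(x, scales, zero_points, axis=0, dtype=torch.quint8)
print(q.dequantize().shape)  # torch.Size([4, 512]); values match x up to rounding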