def __init__(
    self,
    group: dist.ProcessGroup,
    wrap_fsdp: bool,
    cuda_init_mode: CUDAInitMode,
    deterministic: bool,
    **fsdp_kwargs,
):
    super().__init__()
    self.rank = group.rank()
    self.world_size = group.size()
    move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

    def _maybe_wrap(layer):
        if wrap_fsdp:
            return FSDP(layer, group, **fsdp_kwargs)
        return layer

    if deterministic:
        torch.manual_seed(0)
    self.module = nn.Sequential(
        _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
        _maybe_wrap(
            nn.Sequential(
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
            ),
        ),
        _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
        _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
    )
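# Hedged sketch: ``_maybe_cuda`` is assumed (not shown in this excerpt) to move
# a module onto the current CUDA device only when the flag is set. The
# hypothetical helper below is one minimal implementation consistent with how
# it is used in the constructors in this file.
def _maybe_cuda_sketch(module: nn.Module, move_to_cuda: bool) -> nn.Module:
    return module.cuda() if move_to_cuda else module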
def __init__(
    self,
    group: dist.ProcessGroup,
    wrap_fsdp: bool,
    cuda_init_mode: CUDAInitMode,
    delay_before_free_ms: int,
    deterministic: bool,
    **fsdp_kwargs,
):
    super().__init__(
        group=group,
        wrap_fsdp=wrap_fsdp,
        cuda_init_mode=cuda_init_mode,
        deterministic=deterministic,
    )
    self.group = group
    self.delay_before_free_ms = delay_before_free_ms
    self.wrap_fsdp = wrap_fsdp
    self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
    if deterministic:
        # Give each rank different expert parameters
        torch.manual_seed(42 + self.rank)
    d_expert = 23
    d_shared = 12
    d_input = 8
    expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)
    self.num_expert_params = sum(p.numel() for p in expert.parameters())
    for p in expert.parameters():
        p.expert = True  # type: ignore[attr-defined]

    if deterministic:
        # Keep all other parameters the same across ranks
        torch.manual_seed(0)

    shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)

    if wrap_fsdp:
        # We create a process group of size 1 for the expert params
        # (world size 1 means the expert is not sharded).
        expert_group = torch.distributed.new_group([group.rank()])
        expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
        shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]

    self.module = nn.Sequential(
        _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
        shared,
        expert,
        _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
    )
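# Hedged usage sketch (hypothetical, not part of the original tests): assuming
# the constructor above belongs to a class named ``MixtureOfExperts`` and that
# ``dist.init_process_group`` has already been called, each rank would build a
# model whose expert parameters live in a world-size-1 FSDP group (unsharded
# and different per rank) while the shared parameters are sharded across the
# full group.
def _example_build_moe_sketch(**fsdp_kwargs):
    group = dist.group.WORLD
    return MixtureOfExperts(  # hypothetical class name for the ctor above
        group,
        wrap_fsdp=True,
        cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
        delay_before_free_ms=0,
        deterministic=True,
        **fsdp_kwargs,
    )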
def _allgather_then_aggregate_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
    """
    Similar to ``allreduce_hook``, this hook first gathers the ``GradBucket``
    tensors, and its ``then`` callback aggregates the gathered gradient tensors
    and takes the mean. Instead of ``allreduce``, this hook uses ``allgather``.
    Note that with ``W`` workers, both the computation and communication time
    scale as ``O(W)`` for ``allgather``, compared to ``O(log W)`` for
    ``allreduce``. Therefore, this hook is expected to be much slower than
    ``allreduce_hook``, although both do essentially the same thing with the
    gradients.

    .. warning ::
        This hook is for tests and experiments only. Users are advised to use
        the faster ``allreduce_hook``, which relies on the ``allreduce``
        protocol instead of ``allgather``.

    Example::
        >>> ddp_model.register_comm_hook(process_group, _allgather_then_aggregate_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]
    fut = dist.all_gather(
        _get_allgather_out_list(tensor, world_size),
        tensor,
        group=group_to_use,
        async_op=True,
    ).get_future()

    def aggregate(fut):
        all_ranks_tensor = fut.value()[0]
        tensor = bucket.get_tensors()[0]
        # Sum the gradients gathered from the other ranks into the local
        # gradient tensor, then take the mean over the world size.
        for r, gathered_tensor in enumerate(all_ranks_tensor):
            if r != rank:
                tensor += gathered_tensor

        return [tensor.div_(world_size)]

    return fut.then(aggregate)
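# Hedged sketch: ``_get_allgather_out_list`` is assumed (not shown in this
# excerpt) to preallocate one output tensor per rank, shaped and typed like the
# input, for ``dist.all_gather`` to write into. The hypothetical helper below
# is one implementation consistent with how the hooks in this file use it.
def _get_allgather_out_list_sketch(all_gather_in_list, world_size):
    return [
        torch.zeros_like(
            all_gather_in_list,
            device=all_gather_in_list.device,
            dtype=all_gather_in_list.dtype,
        )
        for _ in range(world_size)
    ]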
def __init__(
    self,
    group: dist.ProcessGroup,
    cuda_init_mode: CUDAInitMode,
    add_bn: bool,
    deterministic: bool,
):
    super().__init__()
    self.rank = group.rank()
    self.world_size = group.size()
    if deterministic:
        torch.manual_seed(0)
    d_vocab = 23
    d_model = 16

    self.embed_tokens = nn.Embedding(d_vocab, d_model)
    self.transformer = nn.Transformer(
        d_model=d_model,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=8,
        dropout=0.1,
    )
    self.output_proj = nn.Linear(d_model, d_vocab)

    # share the embedding and output projection weights
    self.output_proj.weight = self.embed_tokens.weight
    self.register_buffer(
        "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
    )
    self.register_buffer(
        "long_buffer",
        torch.zeros_like(self.vocab_bias, dtype=torch.long),
    )  # type: ignore[arg-type]

    self.bs = 2
    self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()
    if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
        self = self.cuda()
    if deterministic:
        self.eval()
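# Hedged usage sketch (hypothetical helper, not in the original file): the
# embedding and output projection above share a single Parameter, and the
# (non-batch-first) ``nn.Transformer`` expects token indices shaped
# (seq_len, batch_size) drawn from the ``d_vocab``-sized vocabulary. The
# sequence lengths below are illustrative only.
def _example_transformer_inputs_sketch(model, device: torch.device):
    assert model.output_proj.weight is model.embed_tokens.weight  # tied weights
    src = torch.randint(0, 23, (12, model.bs), device=device)  # source token ids
    tgt = torch.randint(0, 23, (10, model.bs), device=device)  # target token ids
    return src, tgt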
def quantization_pertensor_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket
) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_tensor`` logic to DDP using the
    ``allgather`` protocol. Workers first ``allgather`` the scale and zero
    point of their own ``GradBucket`` prior to the quantization. After all
    workers have that information, the first ``then`` callback,
    ``quantize_and_allgather``, quantizes the worker's own gradient tensor and
    uses ``allgather`` to communicate the quantized tensors across all workers.
    The final ``then`` callback, ``dequantize_and_aggregate``, dequantizes and
    aggregates each quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental and uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened
        grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    myObserver = torch.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their own ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient
        # tensors locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        return aggregated_dequantized_tensor / world_size

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
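# Hedged sketch: ``_quantize_per_tensor_cuda`` and
# ``_dequantize_per_tensor_cuda`` are assumed (not shown in this excerpt) to
# apply the standard affine quantization round trip using the scale and zero
# point produced by ``MinMaxObserver``. The hypothetical helpers below
# illustrate that logic; the real helpers may differ in clamping and dtype
# details.
def _quantize_per_tensor_cuda_sketch(x, scale, zero_point):
    # q = round(x / scale) + zero_point
    return torch.round(x / scale) + zero_point


def _dequantize_per_tensor_cuda_sketch(y, scale, zero_point):
    # x = (q - zero_point) * scale
    return (y - zero_point) * scale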
def quantization_perchannel_hook(
    process_group: dist.ProcessGroup, bucket: dist.GradBucket, bucket_size=512
) -> torch.futures.Future[torch.Tensor]:
    """
    Applies the ``torch.quantize_per_channel`` logic to DDP using the
    ``allgather`` protocol. Compared to per-tensor quantization, the main
    motivation of per-channel quantization is that, for considerably large
    tensors (e.g., a tensor with 6 million elements), quantizing per bucket of
    512 (or 128) elements may significantly increase the resolution.

    The hook first splits the ``GradBucket`` tensor into multiple chunks
    (channels) of ``bucket_size`` elements. Then, workers ``allgather`` the
    scales and zero points of their own ``GradBucket`` prior to the
    quantization. After all workers have that information, the first ``then``
    callback, ``quantize_and_allgather``, quantizes the worker's own gradient
    tensor and uses ``allgather`` to communicate the quantized tensors across
    all workers. The final ``then`` callback, ``dequantize_and_aggregate``,
    dequantizes, flattens, and aggregates each quantized gradient tensor
    locally and returns the mean.

    .. warning ::
        This is experimental and uses the ``allgather`` protocol, which is
        considerably slower than ``allreduce``. It works only with flattened
        grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = group_to_use.size()

    tensor = bucket.buffer()

    # Pad the flattened gradient to a multiple of ``bucket_size`` and reshape
    # it into channels of ``bucket_size`` elements each.
    tensor_in_channels = (
        nn.functional.pad(
            input=tensor,
            pad=(0, bucket_size - len(tensor) % bucket_size),
            mode="constant",
            value=0,
        )
        .view(-1, bucket_size)
        .cuda(tensor.device)
    )

    myPerChannelObserver = torch.quantization.PerChannelMinMaxObserver().cuda(
        tensor.device
    )
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)
    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient
        # tensors locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        # Flatten back to the original (unpadded) length and take the mean.
        return (
            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
                : tensor.size()[0]
            ]
            / world_size
        )

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)
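# Hedged sketch: ``_quantize_per_channel_cuda`` and
# ``_dequantize_per_channel_cuda`` are assumed (not shown in this excerpt) to
# apply the same affine mapping as the per-tensor helpers, but with one scale
# and zero point per row (channel) of the (num_channels, bucket_size) input.
# The hypothetical helpers below illustrate that logic via broadcasting; the
# real helpers may differ in clamping and dtype details.
def _quantize_per_channel_cuda_sketch(x, scales, zero_points):
    # Broadcast the per-channel scale/zero_point over each row of ``x``.
    return torch.round(x / scales.unsqueeze(1)) + zero_points.unsqueeze(1)


def _dequantize_per_channel_cuda_sketch(y, scales, zero_points):
    return (y - zero_points.unsqueeze(1)) * scales.unsqueeze(1)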