def __init__(
    self,
    group: dist.ProcessGroup,
    wrap_fsdp: bool,
    cuda_init_mode: CUDAInitMode,
    deterministic: bool,
    **fsdp_kwargs,
):
    super().__init__()
    self.rank = group.rank()
    self.world_size = group.size()
    move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

    def _maybe_wrap(layer):
        if wrap_fsdp:
            return FSDP(layer, group, **fsdp_kwargs)
        return layer

    if deterministic:
        torch.manual_seed(0)
    self.module = nn.Sequential(
        _maybe_cuda(nn.Linear(8, 4), move_to_cuda),
        _maybe_wrap(
            nn.Sequential(
                _maybe_wrap(_maybe_cuda(nn.Linear(4, 16), move_to_cuda)),
                _maybe_cuda(nn.Linear(16, 16), move_to_cuda),
            ),
        ),
        _maybe_wrap(_maybe_cuda(nn.Linear(16, 4), move_to_cuda)),
        _maybe_cuda(nn.Linear(4, 8), move_to_cuda),
    )

def fp16_compress_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
    """
    This DDP communication hook implements a simple gradient compression approach that
    converts ``GradBucket`` tensors, which are assumed to be of type ``torch.float32``,
    to half-precision floating point format (``torch.float16``), and then allreduces
    those ``float16`` gradient tensors. Once the compressed gradient tensors are
    allreduced, the chained ``then`` callback ``decompress`` converts the aggregated
    result back to ``float32`` and takes the mean.

    Example::
        >>> ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    compressed_tensor = bucket.get_tensors()[0].to(torch.float16)

    fut = dist.all_reduce(
        compressed_tensor, group=group_to_use, async_op=True
    ).get_future()

    def decompress(fut):
        decompressed_tensor = bucket.get_tensors()[0]
        # Decompress in place to reduce the peak memory.
        # See: https://github.com/pytorch/pytorch/issues/45968
        decompressed_tensor.copy_(fut.value()[0].div_(world_size))
        return [decompressed_tensor]

    return fut.then(decompress)

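# A minimal end-to-end sketch (assumed setup, not from the source) of registering the
# hook above on a DDP-wrapped model; the toy model and the ``process_group`` argument
# are illustrative.
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def _example_register_fp16_hook(local_rank: int, process_group) -> DDP:
    model = nn.Linear(16, 16).cuda(local_rank)
    ddp_model = DDP(model, device_ids=[local_rank], process_group=process_group)
    # Gradients are cast to float16 for the allreduce and restored to float32 afterwards.
    ddp_model.register_comm_hook(process_group, fp16_compress_hook)
    return ddp_model
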
def allreduce_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
    """
    This DDP communication hook just calls ``allreduce`` using ``GradBucket`` tensors.
    Once gradient tensors are aggregated across all workers, its ``then`` callback
    takes the mean and returns the result. If a user registers this hook, the DDP
    results are expected to be the same as the case where no hook was registered.
    Hence, this hook does not change the behavior of DDP, and users can use it as a
    reference or modify it to log useful information, or for any other purpose, without
    affecting DDP behavior.

    Example::
        >>> ddp_model.register_comm_hook(process_group, allreduce_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]
    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

    def then_callback(fut):
        return [fut.value()[0].div_(world_size)]

    return fut.then(then_callback)

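# A minimal sketch (not part of the source above) of how ``allreduce_hook`` could be
# modified to log bucket sizes while leaving DDP numerics unchanged, as suggested in
# the docstring; the logger and the name ``logging_allreduce_hook`` are illustrative
# assumptions.
import logging

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


def logging_allreduce_hook(process_group: dist.ProcessGroup, bucket) -> torch.futures.Future:
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]
    # Log how many gradient elements this bucket carries before reducing it.
    logger.debug("allreducing a bucket with %d elements", tensor.numel())

    fut = dist.all_reduce(tensor, group=group_to_use, async_op=True).get_future()

    def then_callback(fut):
        # Same averaging as the reference hook, so DDP results are unaffected.
        return [fut.value()[0].div_(world_size)]

    return fut.then(then_callback)
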
def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
    key = (tensor.dtype, tensor.device, group)
    if key not in self.buckets:
        # Buckets are divided into world_size pieces; bucket.data is shaped (world_size, shard_size).
        world_size = group.size()
        shard_size = self._get_shard_size(tensor.element_size(), world_size)
        data = tensor.new_zeros((world_size, shard_size))
        self.buckets[key] = Bucket(data, group)
    return self.buckets[key]

def _get_bucket(self, tensor: Tensor, group: ProcessGroup) -> Bucket:
    # TODO (Min): the `group` used here in the key is the object hash, not the content
    #     hash. That means if FSDP instances are initialized with different process groups,
    #     even when the group members are in fact the same, we end up creating different
    #     buckets here.
    key = (tensor.dtype, tensor.device, group)
    if key not in self.buckets:
        # Buckets are divided into world_size pieces; bucket.data is shaped (world_size, shard_size).
        world_size = group.size()
        shard_size = self._get_shard_size(tensor.element_size(), world_size)
        data = tensor.new_zeros((world_size, shard_size))
        self.buckets[key] = Bucket(data, group)
        self.buckets[key].setup()
    return self.buckets[key]

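# A possible way to address the TODO above, sketched as an assumption rather than the
# library's actual fix: key buckets on the group's global rank membership (via
# ``get_global_ranks_from_group``, defined later in this section) instead of the
# ProcessGroup object, so two groups with identical members share one bucket.
def _content_based_bucket_key(self, tensor: Tensor, group: ProcessGroup) -> tuple:
    return (tensor.dtype, tensor.device, tuple(get_global_ranks_from_group(group)))
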
def validate_process_group(device: torch.device, process_group: ProcessGroup) -> None:
    """Do a quick test in case the user called FSDP without calling
    torch.cuda.set_device() correctly. This can easily happen in the cpu_offload case,
    where the model resides on the CPU.
    """
    if not hasattr(process_group, "allgather"):
        # Likely a dummy pg for unit test, skip checking.
        return

    world_size = process_group.size()
    if "cuda" in str(device):
        input_tensor = torch.ones(1).to(device)
        output = list(torch.zeros(world_size).to(device).chunk(world_size))
        dist.all_gather(output, input_tensor, group=process_group)
        assert torch.cat(output).sum() == float(world_size), (
            f"found {torch.cat(output).sum()} devices in process group but "
            f"world_size={world_size}. Check torch.cuda.set_device is called properly"
        )

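# A minimal usage sketch (assumed context, not from the source): pin the current rank
# to its own GPU before wrapping the model, then sanity-check the process group that
# FSDP will use.
import torch


def _example_validate(local_rank: int, process_group) -> None:
    torch.cuda.set_device(local_rank)              # the call the check guards against forgetting
    device = torch.device("cuda", local_rank)
    validate_process_group(device, process_group)  # asserts if the group/device setup is wrong
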
def _allgather_then_aggregate_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
    """
    Similar to ``allreduce_hook``, this hook first gathers the ``GradBucket`` tensors,
    and its ``then`` callback aggregates the gathered gradient tensors and takes the
    mean. Instead of ``allreduce``, this hook uses ``allgather``. Note that with W
    workers, both the computation and communication time scale as O(W) for allgather
    compared to O(logW) for allreduce. Therefore, this hook is expected to be much
    slower than ``allreduce_hook``, although both essentially do the same thing with
    the gradients.

    .. warning ::
        This is for tests and experiments. Users are advised to use the faster
        alternative ``allreduce_hook``, which uses the ``allreduce`` protocol instead
        of the ``allgather`` protocol.

    Example::
        >>> ddp_model.register_comm_hook(process_group, _allgather_then_aggregate_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]
    fut = dist.all_gather(
        _get_allgather_out_list(tensor, world_size),
        tensor,
        group=group_to_use,
        async_op=True,
    ).get_future()

    def aggregate(fut):
        all_ranks_tensor = fut.value()[0]
        tensor = bucket.get_tensors()[0]
        for r, gathered_tensor in enumerate(all_ranks_tensor):
            if r != rank:
                tensor += gathered_tensor

        return [tensor.div_(world_size)]

    return fut.then(aggregate)

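# The allgather-based hooks in this section call a ``_get_allgather_out_list`` helper
# that is not shown here. A plausible sketch, assuming it merely pre-allocates the
# per-rank output slots that ``dist.all_gather`` fills in:
from typing import List

import torch


def _get_allgather_out_list(all_gather_in_list: torch.Tensor, world_size: int) -> List[torch.Tensor]:
    # One zero tensor per rank, matching the input's shape, dtype and device.
    return [
        torch.zeros_like(all_gather_in_list, device=all_gather_in_list.device)
        for _ in range(world_size)
    ]
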
def __init__(
    self,
    group: dist.ProcessGroup,
    cuda_init_mode: CUDAInitMode,
    add_bn: bool,
    deterministic: bool,
):
    super().__init__()
    self.rank = group.rank()
    self.world_size = group.size()
    if deterministic:
        torch.manual_seed(0)
    d_vocab = 23
    d_model = 16

    self.embed_tokens = nn.Embedding(d_vocab, d_model)
    self.transformer = nn.Transformer(
        d_model=d_model,
        num_encoder_layers=2,
        num_decoder_layers=2,
        dim_feedforward=8,
        dropout=0.1,
    )
    self.output_proj = nn.Linear(d_model, d_vocab)

    # Share the embedding and output projection weights.
    self.output_proj.weight = self.embed_tokens.weight
    self.register_buffer(
        "vocab_bias", self.embed_tokens.weight.new_ones((d_model,))
    )
    self.register_buffer(
        "long_buffer",
        torch.zeros_like(self.vocab_bias, dtype=torch.long),
    )  # type: ignore[arg-type]

    self.bs = 2
    self.bn = torch.nn.BatchNorm1d(self.bs) if add_bn else torch.nn.Identity()

    if cuda_init_mode == CUDAInitMode.CUDA_BEFORE:
        self = self.cuda()
    if deterministic:
        self.eval()

def get_global_ranks_from_group(group: ProcessGroup) -> List[int]:
    return [_get_global_rank(group, r) for r in range(group.size())]

def reduce_scatter_async(
    self,
    input_list: List[Tensor],
    group: ProcessGroup,
    callback_fn: Optional[Callable] = None,
) -> None:
    """
    Reduce-scatter a list of tensors asynchronously, so smaller reductions can be
    bucketed together. The given callback (``callback_fn``) will be called with the
    reduced result at some later time. Call ``flush()`` to force all queued ops and
    callbacks to be executed.

    Note that large inputs will be reduced immediately, and this function may also
    flush the relevant bucket to make room for ``input_list``.

    Args:
        input_list (List[Tensor]): list of tensors to reduce-scatter. List should
            contain ``group.size()`` tensors and each tensor should have identical
            shape, dtype and device.
        group (ProcessGroup): process group for reduction
        callback_fn (Callable, Optional): callback function to call after the
            reduction executes. Function will be called with a single argument
            corresponding to the reduced result.
    """
    world_size = group.size()

    assert (
        len(input_list) == world_size
    ), f"reduce_scatter received {len(input_list)} inputs, expected group.size() ({world_size})"

    first_input = input_list[0]
    first_input_size = first_input.numel()

    bucket_shard_size = self._get_shard_size(first_input.element_size(), world_size)
    if first_input_size > bucket_shard_size:
        # TODO: investigate how to avoid using torch.cat (because it seems to be slow for CPU tensors)
        # Input is too big to fit in the bucket, reduce-scatter directly.
        output = torch.zeros_like(input_list[0])
        if hasattr(dist, "_reduce_scatter_base"):
            input_flattened = torch.cat(input_list)
            dist._reduce_scatter_base(output, input_flattened, group=group)  # type: ignore
        else:
            # fallback
            dist.reduce_scatter(output, input_list, group=group)
        if callback_fn is not None:
            callback_fn(output)
        return

    bucket = self._get_bucket(first_input, group)
    if first_input_size > bucket.data.size(1) - bucket.offset:
        # Not enough space remaining in bucket, flush it now.
        bucket.flush()

    # Copy data from input_list into bucket.
    stacked_input = torch.stack(input_list).view(world_size, first_input_size)
    offset = bucket.offset
    bucket.data[:, offset : offset + first_input_size].copy_(stacked_input)
    bucket.offset += first_input_size

    # Callback will be given the reduced result.
    if callback_fn is not None:
        result_view = bucket.output_shard[offset : offset + first_input_size].view_as(
            first_input
        )
        bucket.callbacks.append(functools.partial(callback_fn, result_view))

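# A minimal caller-side sketch (assumed usage, not from the source): each rank splits a
# gradient into ``group.size()`` equally shaped chunks, hands them to the object that
# defines ``reduce_scatter_async`` above (named ``bucketer`` here as an assumption), and
# receives its reduced shard through the callback. ``flush()`` forces any still-bucketed
# reductions and callbacks to run.
import torch


def _example_reduce_scatter(bucketer, grad: torch.Tensor, group) -> None:
    world_size = group.size()
    # Chunks must have identical shape/dtype/device; assumes numel is divisible by world_size.
    input_list = list(grad.view(world_size, -1).unbind(0))

    def on_reduced(shard: torch.Tensor) -> None:
        # ``shard`` views this rank's portion of the summed result; copy it out if needed.
        _ = shard.clone()

    bucketer.reduce_scatter_async(input_list, group=group, callback_fn=on_reduced)
    bucketer.flush()  # ensure bucketed ops and callbacks have executed
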
def powerSGD_hook(
    process_group: dist.ProcessGroup,
    bucket: dist._GradBucket,
    matrix_approximation_rank: int = 1,
) -> torch.futures.Future:
    """
    This DDP communication hook implements a simplified PowerSGD gradient compression
    algorithm described in https://arxiv.org/abs/1905.13727.
    Once gradient tensors are aggregated across all workers, this hook applies
    compression as follows:
    1) Views the input flattened 1D gradient tensor as a square-shaped tensor M with 0 paddings;
    2) Decomposes M into two low-rank tensors P and Q, such that M = PQ^T, where Q is
       initialized from a standard normal distribution and orthogonalized, and P is
       computed as MQ;
    3) Allreduces P;
    4) Orthogonalizes P;
    5) Computes Q, which is approximately equal to M^TP;
    6) Allreduces Q;
    7) Computes M, which is approximately equal to PQ^T;
    8) Truncates the input tensor to the original length.

    TODO(wayi@): The above procedure does two matmul+allreduce steps per iteration --
    one left multiplication and one right multiplication. For warm start, can take one
    such step at a time, and alternate between them.

    Arguments:
        process_group (dist.ProcessGroup): Process group to communicate.
        bucket (dist._GradBucket): Bucket that stores a 1D flattened gradient tensor
            that batches multiple per-variable tensors. Note that since DDP comm hook
            only supports single process single device mode at this time, only exactly
            one tensor is stored in this bucket.
        matrix_approximation_rank (int): The low rank for matrix approximation.
            Typically only 1 or 2 is used. See https://arxiv.org/pdf/1905.13727.pdf.

    Returns:
        Future handler of the communication, which updates the gradients in place.

    Example::
        >>> state = PowerSGDState(process_group, 1)
        >>> ddp_model.register_comm_hook(state, powerSGD_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    # The input tensor is a flattened 1D tensor.
    input_tensor = bucket.get_tensors()[0]
    device = input_tensor.device
    total_length = input_tensor.shape[0]

    # View the input tensor as a 2D square-shaped tensor, and pad 0s if necessary.
    square_side_length = math.ceil(math.sqrt(total_length))
    padded_total_length = square_side_length ** 2
    input_tensor.resize_(padded_total_length)
    input_tensor[total_length:padded_total_length].fill_(0)
    matrix = input_tensor.view(square_side_length, square_side_length)

    def create_low_rank_tensor(fill_random_values):
        "Returns a low-rank 2D tensor of square_side_length * matrix_approximation_rank."
        if fill_random_values:
            with torch.random.fork_rng(devices=[device]):
                # The seed makes sure that the initial random values are the same across
                # all the DDP replicas. Such a seed should differ at every step.
                # Currently the length of the input tensor is used as the seed, which
                # should be mostly different.
                # TODO(wayi@): Should read the random seed from the state of this hook
                # provided by the constructor.
                torch.manual_seed(total_length)
                return torch.randn(
                    square_side_length, matrix_approximation_rank, device=device
                )
        else:
            return torch.empty(
                square_side_length, matrix_approximation_rank, device=device
            )

    p = create_low_rank_tensor(fill_random_values=False)
    q = create_low_rank_tensor(fill_random_values=True)
    _orthogonalize(q, 0)

    torch.matmul(matrix, q, out=p)
    allreduce_p_fut = dist.all_reduce(p, group=group_to_use, async_op=True).get_future()

    def compute_q(fut):
        p = fut.value()[0]
        _orthogonalize(p, 0)

        torch.matmul(matrix.t(), p, out=q)

        return [
            dist.all_reduce(q, group=group_to_use, async_op=True).get_future().wait()[0]
        ]

    def decompress(fut):
        q = fut.value()[0].div_(world_size)
        torch.matmul(p, q.t(), out=matrix)

        ret = input_tensor.resize_(total_length)
        return [ret]

    return allreduce_p_fut.then(compute_q).then(decompress)

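# A small worked sketch (illustrative only) of the shapes the hook above produces: a
# gradient of N elements is padded up to the next perfect square s * s and approximated
# by P and Q of shape (s, r), so the two allreduces move 2 * s * r values in total
# instead of s * s.
import math


def _powersgd_shapes(total_length: int, matrix_approximation_rank: int = 1):
    square_side_length = math.ceil(math.sqrt(total_length))
    padded_total_length = square_side_length ** 2
    low_rank_shape = (square_side_length, matrix_approximation_rank)
    return padded_total_length, low_rank_shape


# e.g. a 1,000,000-element bucket reshapes to 1000 x 1000 and, at rank 1, communicates
# only 1000 values for P plus 1000 for Q per step.
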
def quantization_pertensor_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket
) -> torch.futures.Future:
    """
    Applies the ``torch.quantize_per_tensor`` logic to DDP using the ``allgather``
    protocol. Workers first allgather the scale and zero point of their own
    ``GradBucket`` prior to the quantization. After all workers have that information,
    the first ``then`` callback, ``quantize_and_allgather``, quantizes the worker's own
    gradient tensor and uses ``allgather`` to communicate it across all workers. The
    final ``then`` callback, ``dequantize_and_aggregate``, dequantizes and aggregates
    each quantized gradient tensor locally and returns the mean.

    .. warning ::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than the ``allreduce`` protocol. It works only with
        flattened grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_pertensor_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]

    myObserver = torch.quantization.MinMaxObserver().cuda(tensor.device)
    myObserver(tensor)

    s, z = myObserver.calculate_qparams()
    s_and_z = torch.FloatTensor([s, z]).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their own ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_tensor_cuda(
            tensor, all_ranks_s_and_z[rank][0], all_ranks_s_and_z[rank][1]
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_tensor_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        return [aggregated_dequantized_tensor / world_size]

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)

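# The hook above relies on ``_quantize_per_tensor_cuda`` / ``_dequantize_per_tensor_cuda``,
# which are not shown in this section. A plausible sketch of the affine mapping they are
# assumed to implement (snap to the grid defined by scale and zero point, and back):
import torch


def _quantize_per_tensor_cuda(x: torch.Tensor, scale, zero_point) -> torch.Tensor:
    return torch.round(x / scale) + zero_point


def _dequantize_per_tensor_cuda(y: torch.Tensor, scale, zero_point) -> torch.Tensor:
    return scale * (y - zero_point)
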
def quantization_perchannel_hook(
    process_group: dist.ProcessGroup, bucket: dist._GradBucket, bucket_size=512
) -> torch.futures.Future:
    """
    Applies the ``torch.quantize_per_channel`` logic to DDP using the ``allgather``
    protocol. Compared to per-tensor quantization, the main motivation of per-channel
    quantization is that for considerably large tensors, such as a tensor that contains
    6 million elements, quantizing per bucket of 512 (or 128) elements may
    significantly increase the resolution.

    It first splits the ``GradBucket`` tensor into multiple chunks (channels) of
    ``bucket_size`` elements. Then, workers allgather the scales and zero points of
    their own ``GradBucket`` prior to the quantization. After all workers have that
    information, the first ``then`` callback, ``quantize_and_allgather``, quantizes the
    worker's own gradient tensor and uses ``allgather`` to communicate it across all
    workers. The final ``then`` callback, ``dequantize_and_aggregate``, dequantizes,
    flattens, and aggregates each quantized gradient tensor locally and returns the
    mean.

    .. warning ::
        This is experimental, and uses the ``allgather`` protocol, which is
        considerably slower than the ``allreduce`` protocol. It works only with
        flattened grads.

    Example::
        >>> ddp_model.register_comm_hook(process_group, quantization_perchannel_hook)
    """
    group_to_use = process_group if process_group is not None else dist.group.WORLD
    rank = process_group.rank() if process_group is not None else dist.get_rank()
    world_size = (
        process_group.size() if process_group is not None else dist.get_world_size()
    )

    tensor = bucket.get_tensors()[0]

    tensor_in_channels = (
        nn.functional.pad(
            input=tensor,
            pad=(0, bucket_size - len(tensor) % bucket_size),
            mode="constant",
            value=0,
        )
        .view(-1, bucket_size)
        .cuda(tensor.device)
    )

    myPerChannelObserver = torch.quantization.PerChannelMinMaxObserver().cuda(tensor.device)
    myPerChannelObserver(tensor_in_channels)

    s_ch, z_ch = myPerChannelObserver.calculate_qparams()
    s_and_z = torch.stack((s_ch, z_ch)).cuda(tensor.device)

    all_ranks_s_and_z = _get_allgather_out_list(s_and_z, world_size)

    # First, allgather scale and zeros.
    fut = dist.all_gather(
        all_ranks_s_and_z, s_and_z, group=group_to_use, async_op=True
    ).get_future()

    def quantize_and_allgather(fut):
        # Store scale and zeros across all workers.
        all_ranks_s_and_z = fut.wait()[0]
        # All workers quantize their corresponding ``GradBucket`` tensors.
        quantized_tensor = _quantize_per_channel_cuda(
            tensor_in_channels,
            all_ranks_s_and_z[rank, 0, :],
            all_ranks_s_and_z[rank, 1, :],
        )
        # Allgather quantized tensors.
        fut = dist.all_gather(
            _get_allgather_out_list(quantized_tensor, world_size),
            quantized_tensor,
            group=group_to_use,
            async_op=True,
        ).get_future()

        return fut.wait()

    def dequantize_and_aggregate(fut):
        all_ranks_quantized_tensor = fut.wait()[0]

        aggregated_dequantized_tensor = torch.zeros_like(
            all_ranks_quantized_tensor[0], device=tensor.device, dtype=torch.float32
        )
        # Using previously allgathered scales and zeros, dequantize gradient tensors
        # locally and then aggregate them.
        for r, quantized_tensor in enumerate(all_ranks_quantized_tensor):
            aggregated_dequantized_tensor += _dequantize_per_channel_cuda(
                quantized_tensor, all_ranks_s_and_z[r][0], all_ranks_s_and_z[r][1]
            )

        return [
            torch.flatten(aggregated_dequantized_tensor).cuda(tensor.device)[
                : tensor.size()[0]
            ]
            / world_size
        ]

    return fut.then(quantize_and_allgather).then(dequantize_and_aggregate)

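# A small worked sketch (illustrative only) of the channel split described above: a
# 6,000,000-element gradient padded to a multiple of ``bucket_size=512`` yields 11,719
# channels, each quantized with its own scale and zero point.
def _num_channels(num_elements: int, bucket_size: int = 512) -> int:
    # Mirrors the padding above: pad by (bucket_size - num_elements % bucket_size).
    padded = num_elements + (bucket_size - num_elements % bucket_size)
    return padded // bucket_size


assert _num_channels(6_000_000, 512) == 11_719
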