def _collect_ddp_bucket_info(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
    rank: int,
    assigned_rank: int,
):
    r"""
    Collects :class:`DistributedDataParallel` gradient bucket information for
    the :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when
    overlapping.

    The bucket's parameters are appended to ``zero._overlap_info``'s
    per-rank and per-bucket parameter lists. If the calling rank is the one
    assigned to the bucket, the position at which the bucket's parameters
    start in that rank's list is also recorded in ``offsets``.

    Arguments:
        bucket (dist.GradBucket): the current gradient bucket.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
        rank (int): the calling process's rank.
        assigned_rank (int): the rank assigned to update the parameters
            corresponding to ``bucket``.

    Raises:
        AssertionError: if ``bucket`` holds no parameters.
    """
    overlap_info = zero._overlap_info
    bucket_index = bucket.index()
    bucket_params = bucket.parameters()
    # The message must be an f-string so that the offending bucket index is
    # actually interpolated into the error
    assert len(bucket_params) > 0, f"Bucket {bucket_index} is empty"
    params_per_rank = overlap_info.params_per_rank
    params_per_bucket = overlap_info.params_per_bucket

    # Collect relevant information
    if assigned_rank == rank:
        # The offset is the number of parameters already assigned to this
        # rank, i.e. where this bucket's parameters begin in the rank's list
        overlap_info.offsets[bucket_index] = len(
            params_per_rank[assigned_rank])
    # NOTE(review): nesting reconstructed from whitespace-flattened source;
    # the two appends below are taken to run for every bucket regardless of
    # which rank is assigned -- confirm
    params_per_rank[assigned_rank].extend(bucket_params)
    params_per_bucket.append(bucket_params)
def _save_ddp_bucket_info(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
):
    r"""
    Records the parameters of one :class:`DistributedDataParallel` gradient
    bucket on the :class:`ZeroRedundancyOptimizer` instance ``zero`` so that
    they are available when overlapping.

    This is per-bucket bookkeeping only: it is meant to be invoked once for
    each gradient bucket as it is seen and does not save or compute any
    global information.

    Arguments:
        bucket (dist.GradBucket): the current gradient bucket.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
    """
    info = zero._overlap_info
    params = bucket.parameters()
    assert len(params) > 0, "Empty bucket"

    # Every bucket's parameters are remembered, whether or not sharding
    info.params_per_bucket.append(params)
    if not info.shard_buckets:
        return

    # When sharding buckets, also accumulate the bucket's element count into
    # the running total that the assignment heuristic consumes
    num_elements = sum(p.numel() for p in params)
    assert info.total_size is not None
    info.total_size += num_elements
def hook_with_zero_fn(
    state: Any,
    bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
    r"""
    Returns a :class:`Future` that gives a gradient bucket tensor and
    performs the equivalent of a :class:`ZeroRedundancyOptimizer`
    :meth:`step` if ``bucket`` is the last gradient bucket.

    The function performs additional computation on the iteration that
    the :class:`DistributedDataParallel` buckets are rebuilt to collect
    information used to implement the modified hook.

    The returned future is always the one produced by ``hook(state, bucket)``.

    Arguments:
        state (Any): any state for the hook.
        bucket (dist.GradBucket): the :class:`DistributedDataParallel`
            gradient bucket.
    """
    fut = hook(state, bucket)
    overlap_info = zero._overlap_info
    bucket_index = bucket.index()

    # Proceed as normal until the DDP buckets have been rebuilt
    if not ddp._has_rebuilt_buckets:
        assert overlap_info.status == _OverlapStatus.UNINITIALIZED
        return fut

    if overlap_info.status == _OverlapStatus.UNINITIALIZED:
        overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS

    rank = zero.global_rank
    rank_to_update = zero._ddp_bucket_index_to_rank(bucket_index)

    # Once DDP buckets have been rebuilt but ZeRO has not been
    # properly initialized yet, collect the information needed
    # NOTE: This performs the same bookkeeping as
    # `_collect_ddp_bucket_info()`
    if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
        bucket_params = bucket.parameters()
        assert len(bucket_params) > 0, "Empty bucket"
        params_per_rank = overlap_info.params_per_rank
        params_per_bucket = overlap_info.params_per_bucket
        if rank_to_update == rank:
            overlap_info.offsets[bucket_index] = len(
                params_per_rank[rank_to_update])
        # NOTE(review): nesting reconstructed from whitespace-flattened
        # source; `params_per_bucket` must hold every bucket since its
        # length is used below as the bucket count -- confirm that the
        # `extend()` is likewise unconditional
        params_per_rank[rank_to_update].extend(bucket_params)
        params_per_bucket.append(bucket_params)
        return fut

    assert overlap_info.status == _OverlapStatus.INITIALIZED

    # Save the bucket reference and all-reduce future for the final bucket
    if rank_to_update == rank:
        overlap_info.bucket_index_to_bucket[bucket_index] = bucket
        overlap_info.bucket_index_to_future[bucket_index] = fut

    # NOTE: The implementation from this point forward assumes that the
    # buckets are indexed incrementally starting from 0 in the order of
    # their autograd hooks firing
    num_buckets = len(overlap_info.params_per_bucket)
    is_last_bucket = bucket_index == num_buckets - 1
    if not is_last_bucket:
        return fut

    # Perform partial optimizer step on all buckets after the final
    # bucket has been computed
    # NOTE: This should not be chained as a callback to the last bucket's
    # all-reduce future since that would add synchronization that delays
    # all optimizer computation to wait for that last all-reduce
    # The local parameter count does not depend on the bucket, so it is
    # computed once rather than on every iteration
    num_local_optim_params = len(zero.optim.param_groups[0]["params"])
    # Distinct loop names are used so that the current bucket's
    # `bucket_index` and `rank_to_update` are not overwritten
    for index in range(num_buckets):
        assigned_rank = zero._ddp_bucket_index_to_rank(index)
        if assigned_rank == rank:
            # Only the slots belonging to this bucket are filled in; all
            # other entries stay as the `_NO_PARAM_UPDATE` placeholder
            gradients: List[Optional[torch.Tensor]] = \
                [_NO_PARAM_UPDATE] * num_local_optim_params
            assert index in overlap_info.offsets, \
                f"Bucket index {index} was not assigned to rank {rank}"
            offset = overlap_info.offsets[index]
            # Ensure that the all-reduce completes before performing the
            # parameter update
            assert index in overlap_info.bucket_index_to_future, \
                f"All-reduce future for bucket {index} not saved " \
                f"on rank {rank}"
            allreduce_future = overlap_info.bucket_index_to_future[index]
            allreduce_future.wait()
            bucket_gradients = \
                overlap_info.bucket_index_to_bucket[index].gradients()
            for i, grad in enumerate(bucket_gradients):
                gradients[offset + i] = grad
            zero._local_step(gradients)
        # Every rank takes part in the broadcast of every bucket, with the
        # assigned rank acting as the source
        device = overlap_info.params_per_bucket[index][0].device
        device_index = zero._device_to_device_index[device]
        assert index in zero._buckets[device_index][assigned_rank]
        overlap_info.broadcast_handles.append(
            dist.broadcast(
                zero._buckets[device_index][assigned_rank][index],
                src=assigned_rank,
                async_op=True,
            )
        )

    # Ensure that all parameter updates are finished before the
    # next forward pass
    assert len(overlap_info.broadcast_handles) == num_buckets, \
        f"Missing at least one broadcast handle on rank {rank}"
    for handle in overlap_info.broadcast_handles:
        handle.wait()
    overlap_info.broadcast_handles.clear()

    # Reset per-iteration information
    overlap_info.bucket_index_to_future.clear()
    overlap_info.bucket_index_to_bucket.clear()

    return fut