Example #1
def _collect_ddp_bucket_info(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
    rank: int,
    assigned_rank: int,
):
    r"""
    Collects :class:`DistributedDataParallel` gradient bucket information for
    the :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when
    overlapping.

    Arguments:
        bucket (dist.GradBucket): the current gradient bucket.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
        rank (int): the calling process's rank.
        assigned_rank (int): the rank assigned to update the parameters
            corresponding to ``bucket``.
    """
    overlap_info = zero._overlap_info
    bucket_index = bucket.index()
    bucket_params = bucket.parameters()
    assert len(bucket_params) > 0, "Bucket {bucket_index} is empty"
    params_per_rank = overlap_info.params_per_rank
    params_per_bucket = overlap_info.params_per_bucket

    # Collect relevant information
    if assigned_rank == rank:
        overlap_info.offsets[bucket_index] = len(
            params_per_rank[assigned_rank])
    params_per_rank[assigned_rank].extend(bucket_params)
    params_per_bucket.append(bucket_params)
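
The bookkeeping above is easiest to see in isolation. Below is a minimal, self-contained sketch that reproduces the same offset and per-rank parameter accounting with plain Python containers; ``FakeOverlapInfo``, the round-robin ``bucket_index_to_rank`` mapping, and the dummy parameter names are illustrative assumptions, not part of the PyTorch API.

from typing import Dict, List


class FakeOverlapInfo:
    """Hypothetical stand-in for ZeroRedundancyOptimizer._overlap_info."""

    def __init__(self, world_size: int):
        self.params_per_rank: List[List[str]] = [[] for _ in range(world_size)]
        self.params_per_bucket: List[List[str]] = []
        self.offsets: Dict[int, int] = {}


def bucket_index_to_rank(bucket_index: int, world_size: int) -> int:
    # Assumed round-robin assignment of buckets to ranks
    return bucket_index % world_size


world_size, my_rank = 2, 0
info = FakeOverlapInfo(world_size)

# Simulate three gradient buckets, each holding named parameters
buckets = [["w0", "w1"], ["w2"], ["w3", "w4", "w5"]]
for bucket_index, bucket_params in enumerate(buckets):
    assigned_rank = bucket_index_to_rank(bucket_index, world_size)
    if assigned_rank == my_rank:
        # Offset of this bucket's parameters within the assigned rank's list
        info.offsets[bucket_index] = len(info.params_per_rank[assigned_rank])
    info.params_per_rank[assigned_rank].extend(bucket_params)
    info.params_per_bucket.append(bucket_params)

print(info.offsets)          # {0: 0, 2: 2} on rank 0
print(info.params_per_rank)  # [['w0', 'w1', 'w3', 'w4', 'w5'], ['w2']]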
Example #2
def _save_ddp_bucket_info(
    bucket: dist.GradBucket,
    zero: ZeroRedundancyOptimizer,
):
    r"""
    Saves :class:`DistributedDataParallel` gradient bucket information for the
    :class:`ZeroRedundancyOptimizer` instance ``zero`` to use when overlapping.
    In particular, this function is meant to be called upon seeing each
    gradient bucket, meaning it does not save or compute any global
    information.

    Arguments:
        bucket (dist.GradBucket): the current gradient bucket.
        zero (ZeroRedundancyOptimizer): the calling process's
            :class:`ZeroRedundancyOptimizer` instance.
    """
    overlap_info = zero._overlap_info
    bucket_params = bucket.parameters()
    assert len(bucket_params) > 0, "Empty bucket"

    # Save the parameters in the bucket
    overlap_info.params_per_bucket.append(bucket_params)
    if overlap_info.shard_buckets:
        # Additionally save the bucket size for the assignment heuristic to use
        bucket_size = 0
        for param in bucket_params:
            bucket_size += param.numel()
        assert overlap_info.total_size is not None
        overlap_info.total_size += bucket_size
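
When ``shard_buckets`` is enabled, the accumulated bucket sizes feed an assignment heuristic. The sketch below is a hypothetical greedy variant that assigns each bucket to the currently least-loaded rank by element count; the heuristic and the helper name are assumptions for illustration and may differ from the actual ZeroRedundancyOptimizer logic.

import heapq
from typing import Dict, List


def greedy_shard_buckets(bucket_sizes: List[int], world_size: int) -> Dict[int, int]:
    """Assign each bucket to the least-loaded rank so far (hypothetical heuristic)."""
    # Min-heap of (accumulated element count, rank)
    heap = [(0, rank) for rank in range(world_size)]
    heapq.heapify(heap)
    assignment: Dict[int, int] = {}
    for bucket_index, size in enumerate(bucket_sizes):
        load, rank = heapq.heappop(heap)
        assignment[bucket_index] = rank
        heapq.heappush(heap, (load + size, rank))
    return assignment


# Sizes as they would be accumulated from param.numel() per bucket
bucket_sizes = [4_000_000, 1_000_000, 2_500_000, 500_000]
print(greedy_shard_buckets(bucket_sizes, world_size=2))
# {0: 0, 1: 1, 2: 1, 3: 1}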
Example #3
    def hook_with_zero_fn(
        state: Any,
        bucket: dist.GradBucket,
    ) -> torch.futures.Future[torch.Tensor]:
        r"""
        Returns a :class:`Future` that gives a gradient bucket tensor and
        performs the equivalent of a :class:`ZeroRedundancyOptimizer`
        :meth:`step` if ``bucket`` is the last gradient bucket.

        The function performs additional computation on the iteration that
        the :class:`DistributedDataParallel` buckets are rebuilt to collect
        information used to implement the modified hook.

        Arguments:
            state (Any): any state for the hook.
            bucket (dist.GradBucket): the :class:`DistributedDataParallel`
                gradient bucket.
        """
        fut = hook(state, bucket)
        overlap_info = zero._overlap_info
        bucket_index = bucket.index()

        # Proceed as normal until the DDP buckets have been rebuilt
        if not ddp._has_rebuilt_buckets:
            assert overlap_info.status == _OverlapStatus.UNINITIALIZED
            return fut

        if overlap_info.status == _OverlapStatus.UNINITIALIZED:
            overlap_info.status = _OverlapStatus.DDP_HAS_REBUILT_BUCKETS

        rank = zero.global_rank
        rank_to_update = zero._ddp_bucket_index_to_rank(bucket_index)

        # Once DDP buckets have been rebuilt but ZeRO has not been
        # properly initialized yet, collect the information needed
        if overlap_info.status == _OverlapStatus.DDP_HAS_REBUILT_BUCKETS:
            bucket_params = bucket.parameters()
            assert len(bucket_params) > 0, "Empty bucket"
            params_per_rank = overlap_info.params_per_rank
            params_per_bucket = overlap_info.params_per_bucket
            if rank_to_update == rank:
                overlap_info.offsets[bucket_index] = len(
                    params_per_rank[rank_to_update])
            params_per_rank[rank_to_update].extend(bucket_params)
            params_per_bucket.append(bucket_params)
            return fut

        assert overlap_info.status == _OverlapStatus.INITIALIZED

        # Save the bucket reference and all-reduce future for the final bucket
        if rank_to_update == rank:
            overlap_info.bucket_index_to_bucket[bucket_index] = bucket
            overlap_info.bucket_index_to_future[bucket_index] = fut

        # NOTE: The implementation from this point forward assumes that the
        # buckets are indexed incrementally starting from 0 in the order of
        # their autograd hooks firing
        num_buckets = len(overlap_info.params_per_bucket)
        is_last_bucket = bucket_index == num_buckets - 1
        if not is_last_bucket:
            return fut

        # Perform partial optimizer step on all buckets after the final
        # bucket has been computed
        # NOTE: This should not be chained as a callback to the last bucket's
        # all-reduce future since that would add synchronization that delays
        # all optimizer computation to wait for that last all-reduce
        for bucket_index in range(num_buckets):
            rank_to_update = zero._ddp_bucket_index_to_rank(bucket_index)
            num_local_optim_params = len(zero.optim.param_groups[0]["params"])
            if rank_to_update == rank:
                gradients: List[Optional[torch.Tensor]] = \
                    [_NO_PARAM_UPDATE for _ in range(num_local_optim_params)]
                assert bucket_index in overlap_info.offsets, \
                    f"Bucket index {bucket_index} was not assigned to rank {rank}"
                offset = overlap_info.offsets[bucket_index]
                # Ensure that the all-reduce completes before performing
                # the parameter update
                assert bucket_index in overlap_info.bucket_index_to_future, \
                    f"All-reduce future for bucket {bucket_index} not saved " \
                    f"on rank {rank}"
                allreduce_future = overlap_info.bucket_index_to_future[
                    bucket_index]
                allreduce_future.wait()
                bucket_gradients = overlap_info.bucket_index_to_bucket[
                    bucket_index].gradients()
                for i, grad in enumerate(bucket_gradients):
                    gradients[offset + i] = grad
                zero._local_step(gradients)
            device = overlap_info.params_per_bucket[bucket_index][0].device
            device_index = zero._device_to_device_index[device]
            assert bucket_index in zero._buckets[device_index][rank_to_update]
            overlap_info.broadcast_handles.append(
                dist.broadcast(
                    zero._buckets[device_index][rank_to_update][bucket_index],
                    src=rank_to_update,
                    async_op=True))

        # Ensure that all parameter updates are finished before the
        # next forward pass
        assert len(overlap_info.broadcast_handles) == num_buckets, \
            f"Missing at least one broadcast handle on rank {rank}"
        _ = list(map(lambda x: x.wait(), overlap_info.broadcast_handles))
        overlap_info.broadcast_handles.clear()

        # Reset per-iteration information
        overlap_info.bucket_index_to_future.clear()
        overlap_info.bucket_index_to_bucket.clear()

        return fut
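
For context, here is a hedged sketch of how a wrapped hook like ``hook_with_zero_fn`` is typically registered on a :class:`DistributedDataParallel` model. The import paths for ``hook_with_zero_step`` and ``allreduce_hook`` follow recent PyTorch releases and may vary by version, so treat this as an outline under those assumptions rather than a verified recipe.

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.optim import ZeroRedundancyOptimizer
# Module paths below match recent PyTorch releases but may differ by version
from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import allreduce_hook
from torch.distributed.algorithms.ddp_comm_hooks.ddp_zero_hook import hook_with_zero_step


def build_overlapped_ddp(model: torch.nn.Module) -> DDP:
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank}")
    ddp_model = DDP(model.to(device), device_ids=[rank])
    # overlap_with_ddp=True defers ZeRO initialization until DDP has rebuilt
    # its buckets, matching the state machine in hook_with_zero_fn above
    zero = ZeroRedundancyOptimizer(
        ddp_model.parameters(),
        optimizer_class=torch.optim.SGD,
        overlap_with_ddp=True,
        lr=0.01,
    )
    # Wrap the base all-reduce hook so that each bucket's all-reduce is
    # followed by the corresponding partial optimizer step and broadcast;
    # the first argument is the hook state (None uses the default group)
    ddp_model.register_comm_hook(
        None,
        hook_with_zero_step(allreduce_hook, ddp_model, zero),
    )
    return ddp_model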