Example #1
    def _init_from_local_tensor(
        cls,
        local_tensor: torch.Tensor,
        sharding_spec: shard_spec.ShardingSpec,
        *global_size: Sequence[int],
        process_group: dist.ProcessGroup = None,
        init_rrefs=False,
    ) -> "ShardedTensor":
        """
        Initialize a ShardedTensor given only one local tensor, global sharded tensor
        size and sharding spec on each rank.

        Args:
            local_tensor (Tensor): Single tensor of local shard stored in each rank.
            sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how to shard the Tensor.
            global_size (Sequence[int]): Size of the sharded tensor.
            process_group (ProcessGroup, optional): The process group to aggregate on.
                Default: None
            init_rrefs (bool, optional): Whether or not to initialize
                :class:`torch.distributed.rpc.RRef`s pointing to remote shards.
                Need to initialize the RPC Framework if specified as ``True``.
                Default: ``False``.

        Returns:
            A :class:`ShardedTensor` sharded based on the given sharding_spec, with the
                local tensor stored on the current rank.

        Examples:
            >>> # All tensors below are of torch.int64 type.
            >>> # We have 2 process groups, 2 ranks.
            >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank
            >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2]), 0)
            >>> local_tensor
            tensor([[1, 2, 3, 4]]) # Rank 0
            tensor([[3, 4, 5, 6]]) # Rank 1
            >>> sharding_dim = 0
            >>> sharding_spec = ChunkShardingSpec(
                    dim=sharding_dim,
                    placements=[
                        "rank:0/cuda:0",
                        "rank:1/cuda:1",
                    ],
                )
            >>> st = ShardedTensor._init_from_local_tensor(local_tensor, sharding_spec, [2, 4])
            >>> st
            ShardedTensor(
                ShardedTensorMetadata(
                    shards_metadata=[
                        ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0),
                        ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1),
                    ],
                    size=torch.Size([2, 4])
                )
            )
            >>> st.local_tensor()
            tensor([[1, 2, 3, 4]]) # Rank 0
            tensor([[3, 4, 5, 6]]) # Rank 1

        Warning: This API is experimental and subject to change. It does not perform
                 full cross-rank validation; only the local shard on the current rank
                 is validated. We fully rely on the user to ensure the local tensor is
                 sharded according to the sharding spec.
        """
        if not local_tensor.is_contiguous():
            raise ValueError('local_tensor is not a contiguous Tensor.')

        global_tensor_size = _flatten_tensor_size(global_size)
        tensor_properties = TensorProperties(
            dtype=local_tensor.dtype,
            layout=local_tensor.layout,
            requires_grad=local_tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=local_tensor.is_pinned())
        sharded_tensor_metadata = sharding_spec.build_metadata(
            global_tensor_size, tensor_properties)

        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)

        local_shards: List[Shard] = []
        for shard_metadata in sharded_tensor_metadata.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                process_group, shard_metadata.placement)
            if rank == current_rank:
                local_shards.append(Shard(local_tensor, shard_metadata))

        # TODO: figure out how the API should behave when some ranks have no shard
        # see https://github.com/pytorch/pytorch/issues/7313
        return ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards,
            sharded_tensor_metadata,
            process_group=process_group,
            init_rrefs=init_rrefs,
            sharding_spec=sharding_spec,
        )
Example #2
    def _init_from_local_shards(
        cls,
        local_shards: List[Shard],
        *global_size,
        process_group=None,
        init_rrefs=False,
    ):
        # STEP 1: Validate the local shards' metadata
        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)
        world_size = dist.get_world_size(process_group)

        local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None
        global_tensor_size = _flatten_tensor_size(global_size)

        if len(local_shards) > 0:
            local_sharded_tensor_metadata = \
                build_metadata_from_local_shards(local_shards, global_tensor_size, current_rank, process_group)

        # STEP 2. Validate metadata across ranks, and build a global sharded tensor
        # metadata by gathering local ShardedTensorMetadata
        gathered_metadatas: List[Optional[ShardedTensorMetadata]] = []
        if world_size > 1:
            gathered_metadatas = [None for _ in range(world_size)]

            if isinstance(process_group, dist.ProcessGroupNCCL):
                # with GPU/NCCL, we need to set a device for all_gather_object
                # to use as we need to know which device we should put the
                # serialized tensor on before the NCCL collective.
                with torch.cuda.device(current_rank):
                    dist.all_gather_object(gathered_metadatas,
                                           local_sharded_tensor_metadata,
                                           group=process_group)
            else:
                dist.all_gather_object(gathered_metadatas,
                                       local_sharded_tensor_metadata,
                                       group=process_group)
        else:
            gathered_metadatas = [local_sharded_tensor_metadata]

        global_sharded_tensor_metadata = build_global_metadata(
            gathered_metadatas)

        # STEP 3: Validation done, create the actual ShardedTensor and populate fields
        # prepare initialization
        sharded_tensor = cls.__new__(cls)
        sharded_tensor._prepare_init(process_group=process_group,
                                     init_rrefs=init_rrefs)

        # add to metadata and local_shards
        sharded_tensor._metadata = global_sharded_tensor_metadata
        sharded_tensor._local_shards = local_shards
        # make an EnumerableShardingSpec for sharded tensors that are initialized from this API.
        # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
        #       see issue https://github.com/pytorch/pytorch/issues/67244
        sharded_tensor._sharding_spec = EnumerableShardingSpec(
            global_sharded_tensor_metadata.shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
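
A minimal usage sketch for the entry point above, not part of the original source. It assumes the default process group has already been initialized with 2 ranks, one CUDA device per rank, and that the import paths follow the torch.distributed._shard layout (exact paths differ across PyTorch versions).

# Hypothetical usage sketch; assumes dist.init_process_group(...) ran with world_size=2.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor
from torch.distributed._shard.sharding_spec import ShardMetadata

rank = dist.get_rank()
# Each rank holds one [2, 4] shard of a global [4, 4] tensor, split along dim 0.
local_tensor = torch.full((2, 4), float(rank), device=f"cuda:{rank}")
shard_metadata = ShardMetadata(
    shard_offsets=[2 * rank, 0],
    shard_sizes=[2, 4],
    placement=f"rank:{rank}/cuda:{rank}",
)
local_shards = [Shard(tensor=local_tensor, metadata=shard_metadata)]

# Global size is (4, 4); local metadata is validated, all-gathered, and merged.
st = ShardedTensor._init_from_local_shards(local_shards, 4, 4)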
Example #3
    def _init_from_local_shards_and_global_metadata(
        cls,
        local_shards: List[Shard],
        sharded_tensor_metadata: ShardedTensorMetadata,
        process_group=None,
        init_rrefs=False,
    ):
        """
        Initialize a ShardedTensor with local shards and a global
        ShardedTensorMetadata built on each rank.

        Warning: This API is experimental and subject to change. It does
                 not perform cross-rank validation, and fully relies on the
                 user for the correctness of sharded_tensor_metadata on each rank
        process_group = (process_group if process_group is not None else
                         distributed_c10d._get_default_group())
        current_rank = dist.get_rank(process_group)

        shards_metadata = sharded_tensor_metadata.shards_metadata
        tensor_properties = sharded_tensor_metadata.tensor_properties

        if len(shards_metadata) == 0:
            raise ValueError("shards_metadata must not be empty!")

        if tensor_properties.layout != torch.strided:
            raise ValueError(
                'Only torch.strided layout is currently supported')

        sharded_tensor = cls.__new__(cls)
        sharded_tensor._prepare_init(process_group=process_group,
                                     init_rrefs=init_rrefs)

        sharded_tensor._metadata = sharded_tensor_metadata

        local_shard_metadatas = []

        def _raise_if_mismatch(expected,
                               actual,
                               prop_name,
                               rank,
                               is_property=False):
            tensor_property_or_metadata = "tensor property" if is_property else "local ShardMetadata"
            if expected != actual:
                raise ValueError(
                    f"Local shards' tensor {prop_name} property is incompatible with "
                    f"{tensor_property_or_metadata} on rank {rank}: "
                    f"{tensor_property_or_metadata} {prop_name}={expected}, "
                    f"local shard tensor {prop_name}={actual}.")

        # collect local shard metadatas from the global sharded_tensor_metadata
        for shard_metadata in shards_metadata:  # type: ignore[attr-defined]
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_metadata.placement)

            if current_rank == rank:
                local_shard_metadatas.append(shard_metadata)

        if len(local_shards) != len(local_shard_metadatas):
            raise RuntimeError(
                f'Number of local shards ({len(local_shards)}) does not match number of local '
                f'shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) '
                f'on rank ({current_rank}) ')

        for shard in local_shards:
            shard_meta = shard.metadata
            local_shard_tensor = shard.tensor
            rank, local_device = _parse_and_validate_remote_device(
                sharded_tensor._process_group, shard_meta.placement)

            # validate that shard_meta is among the metadata collected from sharded_tensor_metadata
            assert shard_meta in local_shard_metadatas, \
                "local shard metadata not in sharded_tensor_metadata!"

            _raise_if_mismatch(tensor_properties.layout,
                               local_shard_tensor.layout, "layout",
                               current_rank, True)
            if not local_shard_tensor.is_contiguous():
                raise ValueError(
                    'Only torch.contiguous_format memory_format is currently supported'
                )

            _raise_if_mismatch(shard_meta.shard_sizes,
                               list(local_shard_tensor.size()), "size",
                               current_rank)
            _raise_if_mismatch(tensor_properties.pin_memory,
                               local_shard_tensor.is_pinned(), "pin_memory",
                               current_rank, True)
            _raise_if_mismatch(local_device, local_shard_tensor.device,
                               "device", current_rank)
            _raise_if_mismatch(tensor_properties.dtype,
                               local_shard_tensor.dtype, "dtype", current_rank,
                               True)
            _raise_if_mismatch(tensor_properties.requires_grad,
                               local_shard_tensor.requires_grad,
                               "requires_grad", current_rank, True)

        # check if shards_metadata have overlap shards
        validate_non_overlapping_shards_metadata(shards_metadata)

        # check if the shards_metadata is compatible with overall size of the sharded tensor.
        check_tensor(shards_metadata, list(sharded_tensor_metadata.size))

        # done validation, add local_shards
        sharded_tensor._local_shards = local_shards
        # make an EnumerableShardingSpec for sharded tensors that are initialized from this API.
        # TODO: make sharding spec a ChunkShardingSpec by inferring from the metadata list.
        #       see issue https://github.com/pytorch/pytorch/issues/67244
        sharded_tensor._sharding_spec = EnumerableShardingSpec(shards_metadata)

        # run post initialization, i.e. map registration, rpc initialization
        sharded_tensor._post_init()
        return sharded_tensor
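
A hedged sketch of driving this entry point directly, not part of the original source. It assumes 2 ranks with one CUDA device each, identical metadata built on every rank (no cross-rank validation happens here, per the docstring), and that ShardedTensorMetadata/TensorProperties are importable from torch.distributed._shard.sharded_tensor; import paths may differ by PyTorch version.

# Hypothetical usage sketch; assumes dist.init_process_group(...) ran with world_size=2.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    Shard, ShardedTensor, ShardedTensorMetadata, TensorProperties,
)
from torch.distributed._shard.sharding_spec import ShardMetadata

rank = dist.get_rank()
shards_metadata = [
    ShardMetadata(shard_offsets=[0, 0], shard_sizes=[2, 4], placement="rank:0/cuda:0"),
    ShardMetadata(shard_offsets=[2, 0], shard_sizes=[2, 4], placement="rank:1/cuda:1"),
]
metadata = ShardedTensorMetadata(
    shards_metadata=shards_metadata,
    size=torch.Size([4, 4]),
    tensor_properties=TensorProperties(dtype=torch.float32, layout=torch.strided),
)
# Local shard must match its ShardMetadata (size, device, dtype) or validation raises.
local_tensor = torch.zeros(2, 4, device=f"cuda:{rank}")
local_shards = [Shard(tensor=local_tensor, metadata=shards_metadata[rank])]

st = ShardedTensor._init_from_local_shards_and_global_metadata(local_shards, metadata)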
Example #4
 def process_group(self) -> Optional[ProcessGroup]:
     if self._process_group is None:
         # The strategy should have already initialized the process group in setup_environment()
         self._process_group = _get_default_group()
     return self._process_group
Example #5
 def __setstate__(self, state):
     # If serializable, then the process group should be the default one
     self.process_group = _get_default_group()
     self.check_previous_reduction = False
     super(DistributedDataParallel, self).__setstate__(state)
     self._ddp_init_helper()
Example #6
def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank whose data is used as the ground truth
    and is scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._sharded_tensor.ShardedTensor`

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingSpec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    with torch.cuda.device(tensor.device):
        dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Narrow to get the shard for this rank. We don't want autograd to record
    # the narrow op here, and 'local_shard' should be a leaf variable in the
    # autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    st = ShardedTensor._init_from_local_shards(local_shards, tensor.size(), process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
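
A hedged usage sketch for shard_parameter, not part of the original source. It assumes a 2-rank NCCL job with one GPU per rank and that the function and ChunkShardingSpec are importable from the torch.distributed._shard layout; import paths vary across PyTorch versions.

# Hypothetical usage sketch; assumes dist.init_process_group("nccl", ...) ran with world_size=2.
import torch
import torch.distributed as dist
from torch.distributed._shard import shard_parameter
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

rank = dist.get_rank()
torch.cuda.set_device(rank)

linear = torch.nn.Linear(16, 32).cuda(rank)
spec = ChunkShardingSpec(
    dim=0,  # shard the weight's rows across the two ranks
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
# Rank 0's weight is the ground truth and is scattered; linear.weight becomes a ShardedTensor.
shard_parameter(linear, "weight", spec, src_rank=0)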
Example #7
    def __init__(self,
                 module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25,
                 find_unused_parameters=False,
                 check_reduction=False):

        super(DistributedDataParallel, self).__init__()

        assert any((p.requires_grad for p in module.parameters())), (
            "DistributedDataParallel is not needed when a module "
            "doesn't have any parameter that requires a gradient.")

        self.is_multi_device_module = len(
            {p.device
             for p in module.parameters()}) > 1
        distinct_device_types = {p.device.type for p in module.parameters()}
        assert len(distinct_device_types) == 1, (
            "DistributedDataParallel's input module must be on "
            "the same type of devices, but input module parameters locate in {}."
        ).format(distinct_device_types)
        self.device_type = list(distinct_device_types)[0]

        if self.device_type == "cpu" or self.is_multi_device_module:
            assert not device_ids and not output_device, (
                "DistributedDataParallel device_ids and output_device arguments "
                "only work with single-device GPU modules, but got "
                "device_ids {}, output_device {}, and module parameters {}."
            ).format(device_ids, output_device,
                     {p.device
                      for p in module.parameters()})

            self.device_ids = None
            self.output_device = None
        else:
            # Use all devices by default for single-device GPU modules
            if device_ids is None:
                device_ids = _get_all_device_indices()

            self.device_ids = list(
                map(lambda x: _get_device_index(x, True), device_ids))

            if output_device is None:
                output_device = device_ids[0]

            self.output_device = _get_device_index(output_device, True)

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.device = list(self.module.parameters())[0].device
        self.broadcast_buffers = broadcast_buffers
        self.find_unused_parameters = find_unused_parameters
        self.require_backward_grad_sync = True
        self.require_forward_param_sync = True
        self.ddp_join_enabled = False

        if check_reduction:
            # This argument is no longer used since the reducer
            # will ensure reduction completes even if some parameters
            # do not receive gradients.
            pass

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = int(250 * 1024 * 1024)

        # reduction bucket size
        self.bucket_bytes_cap = int(bucket_cap_mb * 1024 * 1024)

        # Sync params and buffers
        self._sync_params_and_buffers(authoritative_rank=0)

        self._ddp_init_helper()
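
For context, a minimal construction sketch for this DistributedDataParallel variant (not part of the original source), assuming a standard one-GPU-per-process NCCL setup with the process group already initialized.

# Hypothetical usage sketch; assumes dist.init_process_group("nccl", ...) has already run.
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Linear(10, 10).cuda(rank)
ddp_model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)

out = ddp_model(torch.randn(8, 10, device=f"cuda:{rank}"))
out.sum().backward()  # gradients are all-reduced across ranks in buckets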
Example #8
    def __init__(self,
                 module,
                 device_ids=None,
                 output_device=None,
                 dim=0,
                 broadcast_buffers=True,
                 process_group=None,
                 bucket_cap_mb=25,
                 find_unused_parameters=False,
                 check_reduction=False):

        super(DistributedDataParallel, self).__init__()

        assert any((p.requires_grad for p in module.parameters())), (
            "DistributedDataParallel is not needed when a module "
            "doesn't have any parameter that requires a gradient.")

        self.is_multi_device_module = len(
            {p.device
             for p in module.parameters()}) > 1
        self.is_cuda = all(
            [p.device.type == 'cuda' for p in module.parameters()])

        if not self.is_cuda or self.is_multi_device_module:
            assert not device_ids and not output_device, (
                "DistributedDataParallel device_ids and output_device arguments "
                "only work with single-device CUDA modules, but got "
                "device_ids {}, output_device {}, and module parameters {}."
            ).format(device_ids, output_device,
                     {p.device
                      for p in module.parameters()})

            self.device_ids = None
            self.output_device = None
        else:
            # Use all devices by default for single-device CUDA modules
            if device_ids is None:
                device_ids = list(range(torch.cuda.device_count()))

            self.device_ids = list(
                map(lambda x: _get_device_index(x, True), device_ids))

            if output_device is None:
                output_device = device_ids[0]

            self.output_device = _get_device_index(output_device, True)

        if self.is_multi_device_module:
            assert self.is_cuda, (
                "DistributedDataParallel with multi-device module only works "
                "with CUDA devices, but module parameters locate in {}."
            ).format({p.device
                      for p in module.parameters()})

        if process_group is None:
            self.process_group = _get_default_group()
        else:
            self.process_group = process_group

        self.dim = dim
        self.module = module
        self.broadcast_buffers = broadcast_buffers
        self.find_unused_parameters = find_unused_parameters
        self.require_backward_grad_sync = True
        self.require_forward_param_sync = True

        if check_reduction:
            # This argument is no longer used since the reducer
            # will ensure reduction completes even if some parameters
            # do not receive gradients.
            pass

        MB = 1024 * 1024

        # used for intra-node param sync and inter-node sync as well
        self.broadcast_bucket_size = int(250 * MB)

        # reduction bucket size
        self.bucket_bytes_cap = int(bucket_cap_mb * MB)

        # Sync params and buffers
        module_states = list(self.module.state_dict().values())
        if len(module_states) > 0:
            self._distributed_broadcast_coalesced(module_states,
                                                  self.broadcast_bucket_size)

        self._ddp_init_helper()
Example #9
def _shard_tensor(tensor: torch.Tensor,
                  sharding_spec: ShardingSpec,
                  src_rank=0,
                  process_group=None):
    """
    Given a :class:`torch.Tensor`, shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank whose data is used
    as the ground truth and is scattered as shards across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Narrow to get the shard for this rank. We don't want autograd to record
    # the narrow op here, and 'local_shard' should be a leaf variable in the
    # autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards,
                                                 tensor.size(),
                                                 process_group=pg)
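
A small worked sketch of the chunking arithmetic that drives the placement loop above, not part of the original source. The two helpers are re-implemented here under the assumption that get_split_size is ceil division and get_chunked_dim_size clamps trailing chunks; they are stand-ins for illustration, not the library functions themselves.

# Hypothetical sketch of the assumed chunking math (not the library implementation).
def split_size(dim_size: int, chunks: int) -> int:
    # ceil(dim_size / chunks)
    return (dim_size + chunks - 1) // chunks

def chunked_dim_size(dim_size: int, split: int, idx: int) -> int:
    # size of chunk `idx`; trailing chunks may be smaller or empty
    return max(min(dim_size, split * (idx + 1)) - split * idx, 0)

dim_size, world_size = 10, 4
split = split_size(dim_size, world_size)            # 3
sizes = [chunked_dim_size(dim_size, split, i) for i in range(world_size)]
print(split, sizes)                                 # 3 [3, 3, 3, 1]
offsets = [split * i for i in range(world_size)]    # shard_offsets along dim: [0, 3, 6, 9]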
Example #10
    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0.,
                 amsgrad=False,
                 dtype=torch.float32,
                 grad_sync_dtype=None,
                 param_sync_dtype=None,
                 device='cuda',
                 process_group=None,
                 distributed_process_group=None,
                 redundant_process_group=None,
                 model_parallel=False,
                 model_parallel_rank=0,
                 average_grad_sync=True,
                 overlap_grad_sync=True,
                 bucket_cap_mb=15,
                 pipeline_size=2,
                 fused_grad_copy=False,
                 max_grad_norm=0.,
    ):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay)
        super(DistributedFusedAdam, self).__init__(params, defaults)

        # Adam options
        if amsgrad:
            raise RuntimeError('DistributedFusedAdam does not support the AMSGrad variant.')

        # Datatype options
        if grad_sync_dtype is None:
            grad_sync_dtype = dtype
        if param_sync_dtype is None:
            param_sync_dtype = grad_sync_dtype
        valid_dtypes = [
            (torch.float32, torch.float16, torch.float16),
            (torch.float32, torch.float32, torch.float32),
        ]
        if (dtype, grad_sync_dtype, param_sync_dtype) not in valid_dtypes:
            raise RuntimeError(
                'Invalid dtypes for DistributedFusedAdam '
                f'(dtype={dtype}, '
                f'grad_sync_dtype={grad_sync_dtype}, '
                f'param_sync_dtype={param_sync_dtype}))')
        if device != 'cuda':
            raise RuntimeError('DistributedFusedAdam only supports GPU')
        self.dtype = dtype
        self.grad_sync_dtype = grad_sync_dtype
        self.param_sync_dtype = param_sync_dtype
        self.device = device

        # Process groups
        self.world_process_group = (
            _get_default_group()
            if process_group is None
            else process_group
        )
        self.distributed_process_group = (
            self.world_process_group
            if distributed_process_group is None
            else distributed_process_group
        )
        self.redundant_process_group = redundant_process_group
        self.world_size = torch.distributed.get_world_size(self.world_process_group)
        self.distributed_rank = torch.distributed.get_rank(self.distributed_process_group)
        self.distributed_size = torch.distributed.get_world_size(self.distributed_process_group)
        self.redundant_size = (
            1
            if self.redundant_process_group is None
            else torch.distributed.get_world_size(self.redundant_process_group)
        )
        if (self.world_size != self.distributed_size * self.redundant_size):
            raise RuntimeError(
                'Invalid process group configuration '
                f'(world process group size = {self.world_size}, '
                f'distributed process group size = {self.distributed_size}, '
                f'redundant process group size = {self.redundant_size})'
            )
        self.model_parallel = model_parallel
        self.model_parallel_rank = model_parallel_rank

        # Grad sync options
        if fused_grad_copy:
            _params = list(self.parameters())
            if (_params
                and any(p.dtype != self.grad_sync_dtype for p in _params)
                and any(p.device != self.device for p in _params)):
                raise RuntimeError(
                    'Attempted to use fused gradient copy in DistributedFusedAdam, '
                    'but parameters do not all have expected '
                    f'dtype ({self.grad_sync_dtype}) and device ({self.device})'
                )
        self.average_grad_sync = average_grad_sync
        self.overlap_grad_sync = overlap_grad_sync
        self.pipeline_size = pipeline_size
        self.fused_grad_copy = fused_grad_copy

        # Grad clipping options
        self.max_grad_norm = max_grad_norm

        # Determine bucket sizes
        dtype_size = torch.finfo(self.grad_sync_dtype).bits // 8
        self.alignment = 128 // dtype_size
        bucket_size = 1024*1024*bucket_cap_mb / dtype_size
        shard_size = bucket_size / self.distributed_size
        shard_size = (int(shard_size) // self.alignment) * self.alignment
        shard_size = max(shard_size, self.alignment)
        bucket_size = shard_size * self.distributed_size
        self.bucket_size = bucket_size
        self.shard_size = shard_size

        # Load CUDA kernels
        global fused_adam_cuda, distributed_adam_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_adam_cuda = importlib.import_module("distributed_adam_cuda")

        # Optimizer state
        self.state['buckets'] = []
        self.state['step'] = 0

        # Objects for gradient synchronization
        self._grads_generated = set()
        self._grads_to_copy = []
        self._pipeline_streams = [torch.cuda.Stream() for _ in range(self.pipeline_size)]

        # Check if collectives have no_copy option
        self._reduce_scatter_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args
        )
        self._all_gather_no_copy = (
            'no_copy' in inspect.getfullargspec(torch.distributed.all_gather).args
        )

        # Attach hooks for gradient synchronization
        self._register_post_backward_hooks()
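
A small worked sketch of the bucket/shard sizing above (counts are in elements, not bytes), not part of the original source. The numbers assume fp16 gradient sync, bucket_cap_mb=15, and a distributed group of 7 ranks so the alignment rounding is visible.

# Hypothetical sketch mirroring the sizing arithmetic above.
import torch

grad_sync_dtype = torch.float16
bucket_cap_mb = 15
distributed_size = 7

dtype_size = torch.finfo(grad_sync_dtype).bits // 8      # 2 bytes per element
alignment = 128 // dtype_size                            # 64 elements
bucket_size = 1024 * 1024 * bucket_cap_mb / dtype_size   # 7864320.0 elements
shard_size = bucket_size / distributed_size              # ~1123474.3
shard_size = (int(shard_size) // alignment) * alignment  # 1123456, rounded down to a multiple of 64
shard_size = max(shard_size, alignment)
bucket_size = shard_size * distributed_size              # 7864192
print(shard_size, bucket_size)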
Example #11
    def __init__(self,
                 module: nn.Module,
                 process_group: Optional[ProcessGroup] = None,
                 cpu_offload: Optional[CPUOffload] = None):
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
        self.process_group = process_group or _get_default_group()
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        # Device for computation: if the module is on GPU, use module.device;
        # if the module is on CPU, use the current CUDA device.
        self.compute_device = _get_default_cuda_device(module)
        self.compute_dtype = _get_data_type(module)

        # Free full params and keep shard only after forward
        self.reshard_after_forward = True

        # setting two factors to avoid underflow and overflow
        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(
            self.world_size)
        self.gradient_postdivide_factor: float = (
            self.world_size / self.gradient_predivide_factor)

        self.numel_padded_per_param: List[int] = []
        self.cpu_offload = cpu_offload or CPUOffload()

        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        params = []
        for param_name, param in module.named_parameters():
            if not hasattr(param, "_is_sharded"):
                params.append(param)

        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
            module, param_list=params)
        del module  # free original module in case it helps garbage collection
        if self._fsdp_wrapped_module.flat_param is not None:
            self.params = [self._fsdp_wrapped_module.flat_param]
        else:
            self.params = []

        # Shard module parameters in place
        self._shard_parameters()

        # Make sure all parameters are sharded.
        for n, p in self.named_parameters():
            if not hasattr(p, "_is_sharded"):
                raise RuntimeError(
                    f"found unsharded parameter: {n} ; {p.size()}")
        self._reset_lazy_init()

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState_.IDLE

        # Flag to guard against preparing gradients multiple times per backward pass.
        self._pre_backward_hook_has_run = False

        # If specified, offload parameter shard to CPU.
        if self.cpu_offload.offload_params:
            for p in self.params:
                self._offload_to_cpu(p)
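
A hedged construction sketch for this (experimental) FSDP wrapper, not part of the original source. It assumes a NCCL process group is already initialized and that FullyShardedDataParallel and CPUOffload are importable from torch.distributed.fsdp; import paths and the constructor signature may differ across PyTorch versions.

# Hypothetical usage sketch; assumes dist.init_process_group("nccl", ...) has already run.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel, CPUOffload

rank = dist.get_rank()
torch.cuda.set_device(rank)

model = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 10),
).cuda(rank)

# Parameters are flattened and sharded in place across the process group.
fsdp_model = FullyShardedDataParallel(model, cpu_offload=CPUOffload(offload_params=False))

loss = fsdp_model(torch.randn(4, 128, device=f"cuda:{rank}")).sum()
loss.backward()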
Example #12
    def __init__(
        self,
        sharding_spec: ShardingSpec,
        *size,
        dtype=None,
        layout=torch.strided,
        requires_grad=False,
        pin_memory=False,
        memory_format=torch.contiguous_format,
        process_group=None,
    ):
        self._rpc_initialized = False
        self._sharded_tensor_id = None
        if rpc._is_current_rpc_agent_set():
            # Validate PG and RPC ranks match.
            pg_rank = dist.get_rank()
            rpc_rank = rpc.get_worker_info().id
            if pg_rank != rpc_rank:
                raise ValueError(
                    f'Default ProcessGroup and RPC ranks must be '
                    f'the same for ShardedTensor, found process group rank: '
                    f'{pg_rank} and RPC rank: {rpc_rank}')

        if layout != torch.strided:
            raise ValueError(
                'Only torch.strided layout is currently supported')

        if memory_format != torch.contiguous_format:
            raise ValueError(
                'Only torch.contiguous_format memory_format is currently supported'
            )

        self._sharding_spec = sharding_spec
        self._dims = list(size)
        self._process_group = (process_group if process_group is not None else
                               distributed_c10d._get_default_group())

        if distributed_c10d._rank_not_in_group(self._process_group):
            raise ValueError(
                f'Global rank: {dist.get_rank()} not part of process group')

        self._local_shards: List[Shard] = []
        self._remote_shards: Dict[int, List[rpc.RRef[Shard]]] = {}
        self._sharding_metadata: List[ShardMetadata] = []
        if isinstance(self._sharding_spec, ChunkShardingSpec):
            self._init_chunked(
                dtype,
                layout,
                requires_grad,
                pin_memory,
                memory_format,
            )
        elif isinstance(self._sharding_spec, EnumerableShardingSpec):
            self._init_enumerable(
                dtype,
                layout,
                requires_grad,
                pin_memory,
                memory_format,
            )
        else:
            raise ValueError(
                f'Unsupported sharding_spec: {self._sharding_spec}')

        with _sharded_tensor_lock:
            global _sharded_tensor_current_id, _sharded_tensor_map
            self._sharded_tensor_id = _sharded_tensor_current_id
            _sharded_tensor_map[self._sharded_tensor_id] = self
            _sharded_tensor_current_id += 1

        # Initialize RPC if available.
        if rpc._is_current_rpc_agent_set():
            self._init_rpc()
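
A minimal creation sketch for this constructor, not part of the original source. It assumes the default process group spans 2 ranks with one GPU each, and that ShardedTensor and ChunkShardingSpec are importable from the experimental _sharded_tensor / _sharding_spec modules referenced by this snippet; these module paths changed in later PyTorch versions.

# Hypothetical usage sketch; assumes dist.init_process_group(...) ran with world_size=2.
from torch.distributed._sharded_tensor import ShardedTensor
from torch.distributed._sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
# Creates an empty 10 x 20 sharded tensor; each rank allocates only its own chunk.
st = ShardedTensor(spec, 10, 20)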
Example #13
 def __setstate__(self, state):
     # If serializable, then the process group should be the default one
     self.process_group = _get_default_group()
     self.check_previous_reduction = False
     super(RandomKSparsifiedDDP, self).__setstate__(state)
     self._ddp_init_helper()
Example #14
    def __init__(self, params,
                 lr=1e-3, bias_correction = True, grad_averaging=True,
                 betas=(0.9, 0.999), eps=1e-8, 
                 weight_decay=0., max_grad_norm=0., 
                 adam_w_mode=True, use_nvlamb=False,
                 step_supports_amp_scaling=True, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0, 
                 e5m2_allgather=False, verbose=False):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
                        max_grad_norm=max_grad_norm)

        super(DistributedFusedLAMB, self).__init__(params, defaults)

        global fused_adam_cuda, distributed_lamb_cuda
        fused_adam_cuda = importlib.import_module("fused_adam_cuda")
        distributed_lamb_cuda = importlib.import_module("distributed_lamb_cuda")

        self._overflow_buf = torch.cuda.IntTensor([0])
        self._has_overflow = False
        self.multi_tensor_lamb_compute_update_term = distributed_lamb_cuda.multi_tensor_lamb_compute_update_term
        self.multi_tensor_lamb_update_weights = distributed_lamb_cuda.multi_tensor_lamb_update_weights
        import amp_C
        self.multi_tensor_l2norm = amp_C.multi_tensor_l2norm

        self._grad_averaging = grad_averaging
        self._adam_w_mode = 1 if adam_w_mode else 0
        self._use_nvlamb = use_nvlamb
        self._step_supports_amp_scaling = step_supports_amp_scaling
        self._is_accumulation_step = False
        self._last_step = False
        self._overlap_reductions = overlap_reductions
        self._global_scale = None
        self._num_blocks = dwu_num_blocks
        self._num_chunks = dwu_num_chunks
        self._e5m2_allgather = e5m2_allgather
        self._verbose = verbose
        self._L2_grad_norm = None
        
        self._current_process_group = c10d._get_default_group()
        self._available_ranks = list(c10d._pg_group_ranks[self._current_process_group].keys())
        self._group_size = torch.cuda.device_count() if dwu_group_size <= 0 else dwu_group_size
        self._world_size = torch.distributed.get_world_size()
        self._num_groups = self._world_size // self._group_size
        self._rank_in_group = torch.distributed.get_rank() % self._group_size

        self._lr = torch.tensor(0.0, dtype=torch.float32, device='cuda')

        self._resume_from_checkpoint = False
        self._step = torch.cuda.IntTensor([0])

        # Master weight, moment, gradient buffers
        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None

        import inspect
        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"

        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
        self._num_ag_pg = dwu_num_ag_pg
        if self._num_groups > 1:
            self._ar_pg = []
            for dev_i in range(self._group_size):
                ranks = [dev_i+j*self._group_size for j in range(self._num_groups)]
                for i in range(self._num_ar_pg):
                    if self._verbose:
                        print(f"creating new group {i}: {ranks}")
                    grp = torch.distributed.new_group(ranks=ranks)
                    if grp != torch.distributed.GroupMember.NON_GROUP_MEMBER:
                        if self._verbose:
                            print(f"group {i}: init barrier (device: {torch.cuda.current_device()})")
                        torch.distributed.barrier(group=grp, device_ids=[torch.cuda.current_device()])
                    if self._verbose:
                        print(f"created new group {i}")

                    if torch.distributed.get_rank() in ranks:
                        self._ar_pg.append(grp)
            self._ar_st = [torch.cuda.Stream() for _ in range(self._num_ar_pg)]
            #for ar_pg in self._ar_pg:
            #    torch.distributed.all_reduce(self._overflow_buf,group=ar_pg)
        rs_ranks = []
        for group_i in range(self._num_groups):
            rs_ranks.append([group_i*self._group_size+j for j in range(self._group_size)])
        self._rs_pg = []
        for group_i in range(self._num_groups):
            ranks = rs_ranks[group_i]
            for i in range(self._num_rs_pg):
                grp = torch.distributed.new_group(ranks=ranks)
                if torch.distributed.get_rank() in ranks:
                    self._rs_pg.append(grp)
            l2_grad_norm_pg = torch.distributed.new_group(ranks=ranks)
            if torch.distributed.get_rank() in ranks:
                self._l2_grad_norm_pg = l2_grad_norm_pg
                #torch.distributed.all_reduce(self._overflow_buf,group=self._l2_grad_norm_pg)
        self._rs_st = [torch.cuda.Stream() for _ in range(self._num_rs_pg)]
        #for rs_pg in self._rs_pg:
        #    torch.distributed.all_reduce(self._overflow_buf,group=rs_pg)
        if self._num_ag_pg == 0:
            self._ag_pg = self._rs_pg
            self._ag_st = self._rs_st
            self._num_ag_pg = self._num_rs_pg
        else:
            self._ag_pg = []
            for group_i in range(self._num_groups):
                ranks = rs_ranks[group_i]
                for i in range(self._num_ag_pg):
                    grp = torch.distributed.new_group(ranks=ranks)
                    if torch.distributed.get_rank() in ranks:
                        self._ag_pg.append(grp)
            self._ag_st = [torch.cuda.Stream() for _ in range(self._num_ag_pg)]
            #for ag_pg in self._ag_pg:
            #    torch.distributed.all_reduce(self._overflow_buf,group=ag_pg)
        self._l2_grad_norm_st = torch.cuda.Stream()
        self._completion_st = torch.cuda.Stream()
        self._step.record_stream(self._completion_st)

        self._reductions_works = [None]*self._num_blocks
        self._allgather_works = [None]*self._num_blocks

        self._one = torch.cuda.IntTensor([1])

        self._first_step = True
        self._lazy_init_stage1_done, self._lazy_init_stage2_done = False, False
        self._param_order = self.AtomicCounter()
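
A small sketch of the rank layout assumed by the group construction above, not part of the original source: all-reduce groups tie together the same device index across groups, while reduce-scatter/all-gather groups cover the ranks inside each group. Shown here for an assumed world_size=8 and group_size=4.

# Hypothetical sketch of the rank layout used above (pure arithmetic, no collectives).
world_size, group_size = 8, 4
num_groups = world_size // group_size  # 2

# All-reduce groups: one per device index, spanning the groups.
ar_ranks = [[dev_i + j * group_size for j in range(num_groups)]
            for dev_i in range(group_size)]
# [[0, 4], [1, 5], [2, 6], [3, 7]]

# Reduce-scatter / all-gather groups: the ranks inside each group.
rs_ranks = [[group_i * group_size + j for j in range(group_size)]
            for group_i in range(num_groups)]
# [[0, 1, 2, 3], [4, 5, 6, 7]]

print(ar_ranks, rs_ranks)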
Example #15
    def __init__(
        self,
        module: nn.Module,
        process_group: Optional[ProcessGroup] = None,
        cpu_offload: Optional[CPUOffload] = None,
        fsdp_auto_wrap_policy: Optional[Callable] = None,
    ):
        torch._C._log_api_usage_once("torch.distributed.fsdp")
        super().__init__()
        # if fsdp_auto_wrap_policy is specified, submodules should not be
        # already wrapped, otherwise we'd attempt to double wrap them resulting
        # in errors.
        if fsdp_auto_wrap_policy is not None:
            self._check_wrapped(
                module,
                check_fn=lambda mod: not isinstance(mod,
                                                    FullyShardedDataParallel),
                err_fn=lambda mod:
                f"Expected {mod} to NOT be FullyShardedDataParallel if auto_wrap is enabled.",
            )
            # TODO: refactor recursive_wrap so that it is not dependent on
            # ConfigAutoWrap.
            config_auto_wrap = ConfigAutoWrap(
                auto_wrap_policy=fsdp_auto_wrap_policy,
                wrapper_cls=FullyShardedDataParallel  # type: ignore[arg-type]
            )
            with config_auto_wrap:
                assert ConfigAutoWrap.in_autowrap_context
                assert ConfigAutoWrap.wrapper_cls == FullyShardedDataParallel
                assert ConfigAutoWrap.auto_wrap_policy == fsdp_auto_wrap_policy
                # This will only wrap the children, and then constructor will
                # run for root module.
                ConfigAutoWrap.recursive_wrap(
                    module,
                    auto_wrap_policy=fsdp_auto_wrap_policy,
                    # Note that we have the recursive_wrap skip wrapping for
                    # the outermost (this) module otherwise it will result in a
                    # double-wrap causing issues.
                    only_wrap_children=True,
                    # FSDP arguments follow.
                    process_group=process_group,
                    cpu_offload=cpu_offload,
                    # Note that recursive_wrap should not call FSDP with wrapping
                    # enabled, as this recursive call handles all wrapping,
                    # including for nested children.
                    fsdp_auto_wrap_policy=None,
                )
            assert not ConfigAutoWrap.in_autowrap_context

        self.process_group = process_group or _get_default_group()
        self.rank = self.process_group.rank()
        self.world_size = self.process_group.size()
        # Device for computation: if the module is on GPU, use module.device;
        # if the module is on CPU, use the current CUDA device.
        self.compute_device = _get_default_cuda_device(module)

        # Free full params and keep shard only after forward
        self.reshard_after_forward = True

        # setting two factors to avoid underflow and overflow
        self.gradient_predivide_factor: float = self._get_gradient_predivide_factor(
            self.world_size)
        self.gradient_postdivide_factor: float = (
            self.world_size / self.gradient_predivide_factor)

        self.numel_padded_per_param: List[int] = []
        self.cpu_offload = cpu_offload or CPUOffload()

        # Only handle params which are not already sharded. This enables
        # sharding individual layers of a Module, with an outer wrapper to
        # shard any leftover parameters.
        params = []
        for param_name, param in module.named_parameters():
            if not hasattr(param, "_is_sharded"):
                params.append(param)

        self._fsdp_wrapped_module: nn.Module = FlattenParamsWrapper(
            module, param_list=params)
        del module  # free original module in case it helps garbage collection
        if self._fsdp_wrapped_module.flat_param is not None:
            self.params = [self._fsdp_wrapped_module.flat_param]
        else:
            self.params = []

        # Shard module parameters in place
        self._shard_parameters()

        # Make sure all parameters are sharded.
        for n, p in self.named_parameters():
            if not hasattr(p, "_is_sharded"):
                raise RuntimeError(
                    f"found unsharded parameter: {n} ; {p.size()}")
        self._reset_lazy_init()

        # Enum to indicate if we're in the forward/backward pass, idle, etc.
        self.training_state = TrainingState_.IDLE

        # Flag to guard against preparing gradients multiple times per backward pass.
        self._pre_backward_hook_has_run = False

        # If specified, offload parameter shard to CPU.
        if self.cpu_offload.offload_params:
            for p in self.params:
                self._offload_to_cpu(p)
Example #16
    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
        # relative imports to avoid circular dependency
        from torch.distributed._shard.sharded_tensor import (
            ShardedTensor
        )
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned()
        )
        current_rank = dist.get_rank(process_group)
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []
        local_tensor = None
        local_metadata = None
        tensors_to_scatter = [None] * dist.get_world_size(process_group)

        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
        chunks = len(self.placements)
        split_size = get_split_size(sharding_dim_size, chunks)
        scatter_shape = list(tensor.size())
        scatter_shape[self.dim] = split_size  # type: ignore[index]

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
            if current_rank == src_rank:
                # Narrow to get the shard for this rank. We don't want autograd to
                # record the narrow op here, and the resulting shard should be a
                # leaf variable in the autograd graph.
                narrowed_tensor = narrow_tensor(tensor, shard_meta)
                if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
                    # The last shard might be smaller than the other shards, so
                    # resize the narrowed tensor to the common scatter shape and use
                    # it for the scatter collective, since dist.scatter requires
                    # same-size inputs on every rank
                    tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
                else:
                    tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()

                tensors_to_scatter[rank] = tensor_to_scatter

            if current_rank == rank:
                local_tensor = torch.empty(
                    scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
                local_metadata = shard_meta

        # each rank should have local_tensor and local_metadata initialized if the
        # metadata list was built correctly.
        assert local_tensor is not None
        assert local_metadata is not None

        # Scatter the shards to all ranks in the pg
        # scatter takes the global rank as ``src``
        src_for_scatter = src_rank
        if process_group is not None and process_group is not distributed_c10d._get_default_group():
            src_for_scatter = distributed_c10d._get_global_rank(process_group, src_for_scatter)

        dist.scatter(
            local_tensor,
            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
            src=src_for_scatter,
            group=process_group
        )

        if list(local_tensor.size()) != local_metadata.shard_sizes:
            # detach again after receiving to ensure local shards remain a leaf node
            local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()

        # Sync requires_grad to local_shard.
        local_tensor.requires_grad = tensor.requires_grad

        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))

        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards,
            tensor_meta,
            process_group=process_group)

        # Manually set sharding_spec
        st._sharding_spec = self

        return st
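
A hedged usage sketch for this shard() method, not part of the original source. It assumes a 2-rank NCCL job with one GPU per rank and that the method lives on ChunkShardingSpec under torch.distributed._shard.sharding_spec; import paths vary across PyTorch versions.

# Hypothetical usage sketch; assumes dist.init_process_group("nccl", ...) ran with world_size=2.
import torch
import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

rank = dist.get_rank()
torch.cuda.set_device(rank)

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
full = torch.randn(8, 4, device=f"cuda:{rank}")  # rank 0's copy is the ground truth
st = spec.shard(full, src_rank=0)                # each rank ends up with a [4, 4] local shard
print(st.local_shards()[0].tensor.shape)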