Example #1
def sharded_detach(args, kwargs, pg):
    self_st = args[0]
    detached_local_shards = [
        Shard(
            local_shard.tensor.detach(),
            metadata=copy.deepcopy(local_shard.metadata),
        ) for local_shard in self_st.local_shards()
    ]
    new_metadata = copy.deepcopy(self_st.metadata())
    new_metadata.tensor_properties.requires_grad = False
    return detached_local_shards, new_metadata
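The handler returns a ``(local_shards, metadata)`` pair rather than a finished ShardedTensor; the sketch below, which assumes an existing ShardedTensor ``st`` on the default process group, shows how a caller could rebuild one from that pair with the same constructor the later examples use.

# Sketch only: rebuild a ShardedTensor from the pair returned by the handler
# above; `st` is an assumed, pre-existing ShardedTensor.
detached_shards, detached_md = sharded_detach((st,), {}, pg=None)
detached_st = ShardedTensor._init_from_local_shards_and_global_metadata(
    detached_shards,
    sharded_tensor_metadata=detached_md,
)
assert not detached_st.metadata().tensor_properties.requires_grad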
Example #2
    def test_reshard_flatten_tensor(self):
        def get_offsets(tensor, shard):
            if self.rank == 0:
                return [0]
            else:
                return [tensor.shape[0] - shard.shape[0]]

        tensor = self._create_tensor()

        shard = _reshard_flatten_tensor(
            self._create_local_chunk(tensor),
            self._create_enumerate_spec(tensor),
            self.world_size,
            self.rank,
            tensor.device,
            _get_default_group(),
        )
        offsets = get_offsets(tensor, shard)
        shard = Shard.from_tensor_and_offsets(shard, offsets, self.rank)
        uneven_sharded_tensor = init_from_local_shards([shard], tensor.numel())

        shard = _reshard_flatten_tensor(
            uneven_sharded_tensor,
            self._create_chunk_spec(),
            self.world_size,
            self.rank,
            tensor.device,
            _get_default_group(),
        )
        offsets = get_offsets(tensor, shard)
        shard = Shard.from_tensor_and_offsets(shard, offsets, self.rank)
        even_sharded_tensor = init_from_local_shards([shard], tensor.numel())

        output = torch.empty(tensor.shape).cuda() if self.rank == 0 else None
        even_sharded_tensor.gather(0, output)
        if self.rank == 0:
            self.assertEqual(tensor, output)
        output = torch.empty(tensor.shape).cuda() if self.rank == 0 else None
        uneven_sharded_tensor.gather(0, output)
        if self.rank == 0:
            self.assertEqual(tensor, output)
Example #3
def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    device_per_node: int,
    pg: dist.ProcessGroup,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. Each rank gets its
    corresponding chunk as the local shard used to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [
            Shard.from_tensor_and_offsets(local_shard, offsets, rank)
        ]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0]
                              for chunk_size in chunk_sizes]))[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    placements = [
        f"rank:{r}/cuda:{r % device_per_node}" for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size,
                      placement) for offset, size, placement in zip(
                          chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ))
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        sharded_tensor_metadata=sharded_tensor_metadata,
        process_group=pg)
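A minimal usage sketch, assuming a 2-rank job with torch.distributed already initialized and CUDA available (the placements above are pinned to CUDA devices); ``full`` is a hypothetical input tensor.

# Shard a hypothetical [8, 4] tensor row-wise across the default group.
full = torch.randn(8, 4)
st = _create_chunk_sharded_tensor(
    full,
    rank=dist.get_rank(),
    world_size=dist.get_world_size(),
    device_per_node=torch.cuda.device_count(),
    pg=distributed_c10d._get_default_group(),
)
# With world_size == 2, each rank holds a [4, 4] chunk and
# st.size() == torch.Size([8, 4]).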
Example #4
def sharded_clone(args, kwargs, pg):
    self_st = args[0]
    desire_memory_format = kwargs.get("memory_format", None)
    if desire_memory_format and desire_memory_format != torch.preserve_format:
        raise RuntimeError(
            "Only support torch.preserve_format for ShardedTensor!")
    cloned_local_shards = [
        Shard(
            local_shard.tensor.clone(memory_format=desire_memory_format),
            metadata=copy.deepcopy(local_shard.metadata),
        ) for local_shard in self_st.local_shards()
    ]
    new_metadata = copy.deepcopy(self_st.metadata())
    return cloned_local_shards, new_metadata
Example #5
def _init_sharded_tensor_from_local_result(
    sharded_tensor,
    local_result,
    tensor_shard_dim,
    result_shard_dim,
    world_size,
    pg,
):
    """
    Given a sharded tensor and the local_result of an op performed on it, create
    a new sharded tensor from the local_result so that the next op can be
    performed on top of the new sharded tensor. This can be seen as the last
    step of the first phase of Megatron-LM style model (tensor) parallelism.

    Args:
        sharded_tensor: Sharded tensor which the op was performed on.
        local_result: A tensor which is from the op performed on the local_shard of
            the sharded_tensor.
        tensor_shard_dim: Dim which the tensor is sharded on.
        result_shard_dim: Dim which the new sharded tensor will be sharded on.
        world_size: number of ranks.
        pg (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Return:
        A :class:`ShardedTensor` object filled with the local intermediate results.
    """
    sharded_weight_metadata = copy.deepcopy(
        sharded_tensor.local_shards()[0].metadata)
    current_offsets = [0] * local_result.dim()
    current_offsets[result_shard_dim] = sharded_weight_metadata.shard_offsets[
        tensor_shard_dim]
    global_size = list(local_result.size())
    global_size[result_shard_dim] = sharded_tensor.size(tensor_shard_dim)
    local_shard_metadata = ShardMetadata(
        shard_offsets=current_offsets,
        shard_sizes=list(local_result.size()),
        placement=sharded_weight_metadata.placement,
    )
    local_shards = [Shard(local_result, local_shard_metadata)]
    new_st = ShardedTensor._init_from_local_shards(local_shards,
                                                   tuple(global_size),
                                                   process_group=pg)

    # Manually set sharding_spec
    new_st._sharding_spec = copy.deepcopy(sharded_tensor._sharding_spec)
    new_st._sharding_spec.dim = result_shard_dim
    return new_st
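As a concrete (hedged) instance, in a column-parallel linear layer the weight is sharded on dim 0 and the local matmul result comes out sharded along the last dim of the output; ``sharded_weight`` and ``inp`` below are assumed to already exist, and a process group is assumed to be initialized.

local_weight = sharded_weight.local_tensor()   # [out_features / world_size, in_features]
local_result = inp.matmul(local_weight.t())    # [batch, out_features / world_size]
new_st = _init_sharded_tensor_from_local_result(
    sharded_weight,
    local_result,
    tensor_shard_dim=0,   # the weight was sharded along its rows
    result_shard_dim=1,   # the result is sharded along its columns
    world_size=dist.get_world_size(),
    pg=distributed_c10d._get_default_group(),
)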
Example #6
 def elementwise_op(types, args=(), kwargs=None, pg=None):
     """
     Handles ``__torch_function__`` dispatch for the elementwise op such
     as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
     This method computes on either a normal tensor or a sharded tensor.
     """
     input = args[0]
     # Validate types
     if not isinstance(input, ShardedTensor):
         raise TypeError("input needs to be a ShardedTensor")
     local_shards_new = []
     for local_shard in input.local_shards():
         local_shards_new.append(
             Shard(op(local_shard.tensor), local_shard.metadata))
     return ShardedTensor._init_from_local_shards_and_global_metadata(
         local_shards_new, input.metadata(), process_group=pg)
Example #7
def sharded_type_as(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.

    Args: same as ``torch.Tensor.type_as``.

    Return:
        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
    """
    st = args[0]
    tensor = args[1]
    if isinstance(tensor, ShardedTensor):
        tensor = tensor.local_tensor()
    new_local_shards = []
    for shard in st.local_shards():
        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
    st_meta = copy.deepcopy(st._metadata)
    st_meta.tensor_properties.dtype = tensor.dtype
    return new_local_shards, st_meta
Example #8
 def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
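     """
     Handle ``__torch_function__`` dispatch for ``op`` by applying it to every
     local shard (or delegating to ``customized_func`` when one is provided)
     and rebuilding a ShardedTensor from the resulting shards and metadata.
     """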
     st = args[0]
     st_metadata = st.metadata()
     local_shards = st.local_shards()
     local_shards_new = []
     if customized_func:
         local_shards_new, st_metadata = customized_func(args, kwargs, pg)
     else:
         for local_shard in local_shards:
             args = (local_shard.tensor, *args[1:])
             local_shards_new.append(
                 Shard(op(*args, **kwargs), local_shard.metadata)
             )
     return ShardedTensor._init_from_local_shards_and_global_metadata(
         local_shards_new,
         st_metadata,
         process_group=pg,
         init_rrefs=st._init_rrefs,
         sharding_spec=st.sharding_spec()
     )
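``op`` and ``customized_func`` are free variables here, so the handler is assumed to live inside a small per-op factory that closes over them; a structural sketch of such a factory, with ``register`` as a hypothetical stand-in for the real registration decorator, might look like this.

# Hypothetical factory, shown only to illustrate where `op` and
# `customized_func` come from; `register` is a made-up placeholder for the
# mechanism that maps `op` calls on ShardedTensor to this handler.
def _make_sharded_op_handler(op, register, customized_func=None):
    @register(op)
    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
        ...  # body as in the example above, closing over op / customized_func
    return sharded_tensor_op_on_local_shards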
Example #9
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
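    """
    Build a 1-D ShardedTensor with ``shards_per_rank`` shards of ``shard_size``
    elements per rank; only the shards owned by ``rank`` are materialized
    locally, while the remaining shards are described purely through metadata.
    """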
    shards_metadata = []
    local_shards = []
    for idx in range(0, world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(shard_offsets=[idx * shard_size],
                                 shard_sizes=[shard_size],
                                 placement=f"rank:{shard_rank}/cpu")
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank)
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([shard_size * len(shards_metadata)]),
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)))

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md)
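A usage sketch, assuming a 2-rank CPU (e.g. gloo) process group is already initialized, matching the ``rank:N/cpu`` placements above.

st = create_sharded_tensor(rank=dist.get_rank(), world_size=2, shards_per_rank=3)
assert st.size() == torch.Size([2 * 3 * 8])   # 6 shards of 8 elements each
assert len(st.local_shards()) == 3            # this rank owns 3 of them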
Example #10
    def elementwise_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the elementwise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        This method computes on either a normal tensor or a sharded tensor.
        """
        input = args[0]
        # Validate types
        if not isinstance(input, ShardedTensor):
            raise TypeError("input needs to be a ShardedTensor")
        local_shards_new = []
        for local_shard in input.local_shards():
            local_shards_new.append(
                Shard(op(local_shard.tensor), local_shard.metadata))
        # TODO: After a new API for sharded tensor creation, we need to replace this.
        # https://github.com/pytorch/pytorch/issues/72092
        new_st = ShardedTensor._init_from_local_shards(local_shards_new,
                                                       input.size(),
                                                       process_group=pg)

        # Manually set sharding_spec
        new_st._sharding_spec = copy.deepcopy(input._sharding_spec)
        return new_st
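Once a handler like this is registered for a given elementwise op, calling that op on a ShardedTensor routes through ``__torch_function__`` to it; a hedged call-site sketch, assuming registration for ``torch.nn.functional.relu`` has already happened and ``st`` is an existing ShardedTensor.

activated = torch.nn.functional.relu(st)   # dispatches to the handler above
assert isinstance(activated, ShardedTensor)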
Example #11
 def _create_local_chunk(self, tensor):
     chunk = tensor.chunk(2)[self.rank]
     offsets = [0] if self.rank == 0 else [tensor.shape[0] - chunk.shape[0]]
     shard = Shard.from_tensor_and_offsets(chunk, offsets, self.rank)
     return init_from_local_shards([shard], tensor.numel())
Example #12
def _shard_tensor(tensor: torch.Tensor,
                  sharding_spec: ShardingSpec,
                  src_rank=0,
                  process_group=None):
    """
    Given a :class:`torch.Tensor`, shard that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank whose data is used
    as the ground truth and is scattered as shards across the rest of the
    ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Narrow to get the shard for this rank. We don't want autograd recording
    # the narrow op here, and 'local_shard' should be a leaf variable in the
    # autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards,
                                                 tensor.size(),
                                                 process_group=pg)
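A minimal usage sketch, assuming a 2-rank CUDA-capable process group is already initialized and ``ChunkShardingSpec`` is imported from ``torch.distributed._shard.sharding_spec``.

spec = ChunkShardingSpec(
    dim=0,
    placements=[
        "rank:0/cuda:0",
        "rank:1/cuda:1",
    ],
)
st = _shard_tensor(torch.randn(8, 4, device="cuda"), spec, src_rank=0)
# Each rank keeps a [4, 4] shard of the original [8, 4] tensor.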
Example #13
def shard_parameter(module: torch.nn.Module,
                    param_name: str,
                    sharding_spec: ShardingSpec,
                    src_rank=0,
                    process_group=None):
    """
    Given a :class:`torch.nn.Module` and a ``param_name`` for a parameter in
    that module, shard that parameter according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank whose data is used
    as the ground truth and is scattered as shards across the rest of the
    ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    # Perform some validation first.
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise ValueError('Only ChunkShardingSpec is supported.')

    if not hasattr(module, param_name):
        raise ValueError(
            f'module: {module} does not have parameter with name: {param_name}'
        )

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(
            f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}'
        )

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Narrow to get the shard for this rank. We don't want autograd recording
    # the narrow op here, and 'local_shard' should be a leaf variable in the
    # autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    st = ShardedTensor._init_from_local_shards(local_shards,
                                               tensor.size(),
                                               process_group=pg)

    # Manually set sharding_spec
    st._sharding_spec = sharding_spec

    # Replace param with ShardedTensor.

    # Need to delete the attribute first since param_name might be
    # torch.nn.Parameter and can't be replaced with ShardedTensor which is
    # not torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)
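A minimal usage sketch, assuming the same 2-rank CUDA setup as above: the weight of a hypothetical Linear layer is sharded along its output features (dim 0 of the [out, in] weight).

linear = torch.nn.Linear(16, 32).cuda()
spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
shard_parameter(linear, "weight", spec)
# linear.weight is now a ShardedTensor; each rank stores a [16, 16] slice.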