Example No. 1
def _reshard_flatten_tensor(
    input_tensor: ShardedTensor,
    output_spec: ShardingSpec,
    world_size: int,
    my_rank: int,
    device: torch.device,
    process_group: Optional[dist.ProcessGroup],
) -> torch.Tensor:
    """
    Reshard a sharded flattened tensor. This is used by FSDP to produce a
    sharded state_dict, since the functionality is not supported by
    ShardedTensor itself. This API is designed to be used by FSDP; therefore it
    supports only 1-D ShardedTensors (hence the name, reshard_flatten_tensor).

    This API uses the ChunkShardingSpec and EnumerableShardingSpec from
    torch.distributed.sharding_spec but ignores the placement field in
    ChunkShardingSpec, as the placement requires the callee to understand the
    number of GPUs per node. The API simply uses the semantics of the sharding
    specs.

    Args:
        input_tensor (ShardedTensor): the original ShardedTensor. Must be 1D.
        output_spec (ShardingSpec): the sharding spec for the output tensor.
        world_size (int): total trainer count.
        my_rank (int): the rank for this trainer.
        device (torch.device): the device on which the output local shard is
            allocated.
        process_group (Optional[dist.ProcessGroup]): the process group used
            for the all-to-all exchange.

    Returns:
        The local shard for the new ShardedTensor.
    """

    input_spec = input_tensor.sharding_spec()
    size = input_tensor.size()
    if isinstance(size, int):
        raise ValueError("The input tensor has no dimensions.")
    tensor_numel = size.numel()
    input_offsets = _sharding_spec_to_offsets(input_spec, tensor_numel,
                                              world_size)
    output_offsets = _sharding_spec_to_offsets(output_spec, tensor_numel,
                                               world_size)
    input_split_sizes, output_split_sizes = _offsets_to_split_sizes(
        input_offsets, output_offsets, tensor_numel, world_size, my_rank)
    output_size = sum(output_split_sizes)
    local_shard = torch.empty(output_size,
                              dtype=input_tensor.dtype,
                              device=device)
    dist.all_to_all_single(
        local_shard,
        input_tensor.local_shards()[0].tensor,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=process_group,
    )
    return local_shard
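For intuition, here is a minimal sketch of the split-size arithmetic the all-to-all above relies on, under the assumption that _sharding_spec_to_offsets returns per-rank start offsets into the flattened tensor and that _offsets_to_split_sizes derives send/receive sizes from the overlap of the input and output layouts. The helper name _toy_split_sizes and the concrete numbers are illustrative only.

# Plain Python, no collectives; only the overlap logic mirrors what the
# all_to_all_single exchange needs.
def _toy_split_sizes(input_offsets, output_offsets, numel, world_size, my_rank):
    # Rank r owns [offsets[r], offsets[r + 1]) of the flattened tensor
    # (the last rank ends at numel).
    def interval(offsets, rank):
        end = offsets[rank + 1] if rank + 1 < world_size else numel
        return offsets[rank], end

    def overlap(a, b):
        return max(0, min(a[1], b[1]) - max(a[0], b[0]))

    my_in = interval(input_offsets, my_rank)
    my_out = interval(output_offsets, my_rank)
    # What my current shard sends to each rank / receives from each rank.
    input_split_sizes = [overlap(my_in, interval(output_offsets, r))
                         for r in range(world_size)]
    output_split_sizes = [overlap(interval(input_offsets, r), my_out)
                          for r in range(world_size)]
    return input_split_sizes, output_split_sizes

# 10 elements, input layout [0, 3, 6, 8], output layout [0, 2, 5, 7], rank 1:
print(_toy_split_sizes([0, 3, 6, 8], [0, 2, 5, 7], 10, 4, my_rank=1))
# -> ([0, 2, 1, 0], [1, 2, 0, 0])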
Example No. 2
 def sharded_tensor_op_on_local_shards(types,
                                       args=(),
                                       kwargs=None,
                                       pg=None):
     st = args[0]
     sharding_spec = st.sharding_spec()
     _chunk_sharding_spec_check(sharding_spec, op)
     if len(st.local_shards()) != 1:
         raise TypeError(
             f"torch function '{op.__name__}', with args: {args} and "
             f"kwargs: {kwargs} only supported for single local tensor!")
     st_size = st.size()
     if customized_func:
         local_tensor, sharding_spec, st_size = customized_func(
             args, kwargs, pg)
     else:
         args = (st.local_tensor(), *args[1:])
         local_tensor = op(*args, **kwargs)
     return ShardedTensor._init_from_local_tensor(
         local_tensor.contiguous(),
         sharding_spec,
         st_size,  # type: ignore[arg-type]
         process_group=pg,
         init_rrefs=st._init_rrefs,
     )
Example No. 3
    def binary_math_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the binary math ops
        such as `torch.add`, `torch.mul`, `torch.div`, etc.
        This method computes on ShardedTensor
        """
        if len(args) != 2:
            raise ValueError("Only support binary math op on ShardedTensor for now!")
        lhs = args[0]
        rhs = args[1]
        # Validate types
        if isinstance(lhs, ShardedTensor) and isinstance(rhs, ShardedTensor):
            lhs_spec = lhs.sharding_spec()
            rhs_spec = rhs.sharding_spec()
            _chunk_sharding_spec_check(lhs_spec, op)
            _chunk_sharding_spec_check(rhs_spec, op)

            if lhs.size() == rhs.size() and lhs_spec.dim == rhs_spec.dim:  # type: ignore[attr-defined]
                # perform local element-wise math op
                res = op(lhs.local_tensor(), rhs.local_tensor())
                return ShardedTensor._init_from_local_tensor(
                    res,
                    lhs_spec,
                    lhs.size(),  # type: ignore[arg-type]
                    process_group=pg)
            else:
                raise RuntimeError("Implicit broadcasting not supported yet!")
        else:
            # Try dispatch to ShardingSpec agnostic ops.
            return binary_math_op_impl(op, types, args, kwargs, pg)
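A hedged usage sketch of the ShardedTensor-with-ShardedTensor branch: with an initialized 2-rank process group, two tensors built from the same ChunkShardingSpec, and the sharded_tensor.rand creation op available in this torch version, a plain torch.add dispatches to the handler above and reduces to an element-wise op on the local shards.

# Not runnable as-is: assumes torch.distributed is initialized with
# world_size=2 and one GPU per rank.
import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(
    dim=0,
    placements=["rank:0/cuda:0", "rank:1/cuda:1"],
)
lhs = sharded_tensor.rand(spec, 8, 4)
rhs = sharded_tensor.rand(spec, 8, 4)

# Same size and same sharding dim, so the handler adds the local shards and
# wraps the result in a new ShardedTensor carrying the lhs sharding spec.
out = torch.add(lhs, rhs)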
Example No. 4
def _prepare_sharded_tensor_read(
    fqn: str,
    storage_metadata: Dict[MetadataIndex, str],
    metadata: TensorStorageMetadata,
    sharded_tensor_out: ShardedTensor
) -> List[TensorReadRequest]:
    """
    Prepare sharded tensor read.

    Args:
        fqn: The FQN of ``sharded_tensor`` in the state_dict.
        storage_metadata: Dictionary describing checkpoint storage.
        metadata: Metadata describing the persisted sharded tensor. Normally,
                  this is generated by :func:`_prepare_sharded_tensor_write`.
        sharded_tensor_out: The ShardedTensor being read.

    Returns:
        A list of :class:`TensorReadRequest`. When fulfilled,
        ``sharded_tensor_out``'s local shards are loaded from the persisted
        sharded tensor.
    """
    return _prepare_generic_tensor_read(
        fqn,
        metadata.chunks,
        sharded_tensor_out.local_shards(),
        storage_metadata)
Example No. 5
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor only has one dimension, we add one more dimension
    for resharding. We then need to squeeze manually to reduce the dimension
    later on.

    For example, if we have:
    input: size[15]
    weight: size[15, 16]
    world_size: 4

    In each rank, we will have 4 * [4] tensors. We then stack them into a [4, 4]
    tensor and generate a sharded tensor sharded by dim 1.

    For all other situations, we simply concatenate the local tensors. No more
    actions are needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: the rank of the current CUDA process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object which is filled with local intermediate results.
    """
    # allgather the inputs first.
    gathered_inputs = all_gather(input, group=pg)
    (start_pos,
     chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                             weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(world_size, start_pos, chunk_size,
                                         weight, pg, bias)
    results = []
    for i, inp in enumerate(gathered_inputs):
        results.append(inp.matmul(local_shard_t) + local_bias)
    # When the local result only has one dimension, we need to make sure it is
    # not sharded along dim 0, so that resharding can work properly.
    if results[0].dim() == 1:  # type: ignore[attr-defined]
        result = torch.stack(results)  # type: ignore[arg-type]
    else:
        result = torch.cat(results)  # type: ignore[arg-type]
    st_size = list(result.size())
    st_size[-1] = weight.size(0)
    new_sharding_spec = ChunkShardingSpec(
        dim=-1, placements=weight.sharding_spec().placements)
    return ShardedTensor._init_from_local_tensor(
        result,
        new_sharding_spec,
        *st_size,  # type: ignore[arg-type]
        process_group=pg,
    )
Example No. 6
def _compute_sharded_tensor_md(
        tensor: ShardedTensor,
        shard_to_storage_key: Dict[str, str]) -> ShardedTensorStorageMetadata:
    smd = []
    for shard_md in tensor.metadata().shards_metadata:
        shard_storage_key = shard_to_storage_key[_get_shard_key(shard_md)]

        one_smd = ShardStorageMetadata(
            shard_metadata=shard_md,
            storage_key=shard_storage_key,
        )
        smd.append(one_smd)

    return ShardedTensorStorageMetadata(
        tensor_metadata=tensor.metadata(),
        storage_metadata=smd,
    )
Example No. 7
def _all_gather_sharded_tensor(
        sharded_tensor: ShardedTensor,
        pg: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    local_tensor = shards[0].tensor.flatten()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(
        dim_0_size / world_size) * tensor_numel // dim_0_size
    num_padding = chunk_size - local_tensor.numel()
    if num_padding > 0:
        local_tensor = F.pad(local_tensor, [0, num_padding])
    tensor = torch.empty(chunk_size * world_size,
                         dtype=local_tensor.dtype).cuda()
    dist._all_gather_base(tensor, local_tensor, group=pg)
    return tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
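The chunk_size arithmetic is worth a worked case: it is the padded per-rank element count after FSDP's dim-0 chunking, chosen so that every rank contributes equally to the all-gather. A plain-Python illustration with assumed example numbers:

import math

# Global tensor of shape [10, 3], sharded on dim 0 across 4 ranks.
dim_0_size, tensor_numel, world_size = 10, 10 * 3, 4

# Elements per rank, assuming dim-0 chunks of ceil(10 / 4) = 3 rows.
chunk_size = math.ceil(dim_0_size / world_size) * tensor_numel // dim_0_size
print(chunk_size)  # 9 elements, i.e. 3 rows of 3

# The last rank only holds 1 row (3 elements) and pads with 6 zeros so every
# rank feeds exactly chunk_size elements into _all_gather_base; the gathered
# result is then narrowed back to tensor_numel and reshaped.
local_numels = [3 * 3, 3 * 3, 3 * 3, 1 * 3]
print([chunk_size - n for n in local_numels])  # [0, 0, 0, 6]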
Example No. 8
def _prepare_sharded_tensor_read(
        metadata: ShardedTensorStorageMetadata,
        sharded_tensor_out: ShardedTensor) -> List[TensorReadRequest]:
    """
    Prepare sharded tensor read.

    Args:
        metadata: Metadata describing the persisted sharded tensor. Normally,
                  this is generated by :func:`_prepare_sharded_tensor_write`.
        sharded_tensor_out: The destination sharded tensor.

    Returns:
        A list of :class:`TensorReadRequest`. When fulfilled,
        ``sharded_tensor_out``'s local shards are loaded from the persisted
        sharded tensor.
    """
    read_reqs = []
    # this is a naive quadratic algo that can be optimized later
    for shard in sharded_tensor_out.local_shards():
        # scan all mds looking for chunks
        for storage_md in metadata.storage_metadata:
            shard_md_from_storage = storage_md.shard_metadata

            # do they overlap?
            if not _check_shard_metadata_pair_overlap(shard.metadata,
                                                      shard_md_from_storage):
                continue

            storage_key = storage_md.storage_key
            target_tensor = shard.tensor.detach()
            offsets = []
            lengths = []
            for (
                    dim,
                    offset_for_saved_tensor,
                    offset_for_current_tensor,
                    length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                    saved_shard=shard_md_from_storage,
                    current_shard=shard.metadata):
                # Note that we do NOT want to make any tensor copy;
                # all operations must be view-only.
                target_tensor = torch.narrow(target_tensor, dim,
                                             offset_for_current_tensor, length)
                offsets.append(offset_for_saved_tensor)
                lengths.append(length)

            read_reqs.append(
                TensorReadRequest(
                    tensor=target_tensor,
                    storage_key=storage_key,
                    offsets=tuple(offsets),
                    lengths=tuple(lengths),
                ))
    return read_reqs
Example No. 9
def _prepare_sharded_tensor_write(
    fqn: str,
    sharded_tensor: ShardedTensor,
    storage_key: str,
    storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], TensorStorageMetadata, Dict[MetadataIndex, str]]:
    """
    Prepare sharded tensor write.

    Args:
        fqn: The FQN of ``sharded_tensor`` in the state_dict.
        sharded_tensor: The sharded tensor to persist.
        storage_key: The identifier for `sharded_tensor`.
        storage_key_to_fqn: dict used to produce storage keys
    Returns:
        A 3-element tuple with the following values:
            List of ``TensorWriteRequest`` for the tensor
            Metadata describing the tensor.
            Dictionary describing storage information for this tensor

    NB `storage_key` is used to compose the key names of the local shards.
    """
    write_requests = []
    shard_to_storage_key: Dict[str, str] = dict()
    storage_md = {}

    for shard_md in sharded_tensor.metadata().shards_metadata:
        shard_storage_key = _get_shard_storage_key(storage_key, shard_md, storage_key_to_fqn)
        shard_to_storage_key[_get_shard_key(shard_md)] = shard_storage_key
        storage_md[MetadataIndex(fqn, shard_md.shard_offsets)] = shard_storage_key

    for shard in sharded_tensor.local_shards():
        tensor = shard.tensor.detach()
        shard_storage_key = shard_to_storage_key[_get_shard_key(shard.metadata)]

        wr = TensorWriteRequest(
            tensor=_trim(tensor),
            storage_key=shard_storage_key,
        )
        write_requests.append(wr)
    return write_requests, _compute_sharded_tensor_md(sharded_tensor), storage_md
Example No. 10
def _create_chunk_sharded_tensor(
    tensor: torch.Tensor,
    rank: int,
    world_size: int,
    device_per_node: int,
    pg: dist.ProcessGroup,
) -> ShardedTensor:
    """
    Shard a tensor into chunks along the first dimension. Each rank gets its
    corresponding chunk as the local shard to create a ShardedTensor.
    """
    chunks = tensor.chunk(world_size, dim=0)
    if len(chunks) > rank:
        local_shard = chunks[rank].clone()
        offsets = [0 for _ in tensor.size()]
        offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank
        local_shards = [
            Shard.from_tensor_and_offsets(local_shard, offsets, rank)
        ]
    else:
        local_shards = []

    # Create a ShardedTensor without invoking communication.
    chunk_sizes = [list(chunk.size()) for chunk in chunks]
    dim0_offsets = [0] + list(
        itertools.accumulate([chunk_size[0]
                              for chunk_size in chunk_sizes]))[:-1]
    offsets = [0] * (len(chunk_sizes[0]) - 1)
    chunk_offsets = [[d0] + offsets for d0 in dim0_offsets]
    placements = [
        f"rank:{r}/cuda:{r % device_per_node}" for r in range(len(chunk_sizes))
    ]
    assert len(chunk_sizes) == len(chunk_offsets) == len(placements)
    shard_metadata = [
        ShardMetadata(offset, size,
                      placement) for offset, size, placement in zip(
                          chunk_offsets, chunk_sizes, placements)
    ]
    sharded_tensor_metadata = ShardedTensorMetadata(
        shards_metadata=shard_metadata,
        size=tensor.size(),
        tensor_properties=TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=False,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned(),
        ))
    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards,
        sharded_tensor_metadata=sharded_tensor_metadata,
        process_group=pg)
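The metadata bookkeeping above (chunk sizes, dim-0 offsets, placements) needs no process group to understand; the snippet below reproduces just that part for a toy tensor so the smaller last chunk is visible. The tensor shape and device_per_node value are arbitrary examples.

import itertools
import torch

tensor = torch.randn(10, 3)
world_size, device_per_node = 4, 2

chunks = tensor.chunk(world_size, dim=0)
chunk_sizes = [list(chunk.size()) for chunk in chunks]
dim0_offsets = [0] + list(
    itertools.accumulate(size[0] for size in chunk_sizes))[:-1]
chunk_offsets = [[d0] + [0] * (len(chunk_sizes[0]) - 1) for d0 in dim0_offsets]
placements = [
    f"rank:{r}/cuda:{r % device_per_node}" for r in range(len(chunk_sizes))
]

print(chunk_sizes)    # [[3, 3], [3, 3], [3, 3], [1, 3]] -- the last chunk is smaller
print(chunk_offsets)  # [[0, 0], [3, 0], [6, 0], [9, 0]]
print(placements)     # ['rank:0/cuda:0', 'rank:1/cuda:1', 'rank:2/cuda:0', 'rank:3/cuda:1']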
Example No. 11
def _prepare_sharded_tensor_write(
    sharded_tensor: ShardedTensor, storage_key: str,
    storage_key_to_fqn: Dict[str, str]
) -> Tuple[List[TensorWriteRequest], ShardedTensorStorageMetadata]:
    """
    Prepare sharded tensor write.

    Args:
        sharded_tensor: The sharded tensor to persist.
        storage_key: The identifier for `sharded_tensor`.
        storage_key_to_fqn: dict used to produce storage keys

    Returns:
        Write requests for persisting the sharded tensor, and metadata
        describing the persisted sharded tensor.

    NB `storage_key` is used to compose the key names of the local shards.

    """
    write_requests = []
    shard_to_storage_key: Dict[str, str] = dict()

    for shard_md in sharded_tensor.metadata().shards_metadata:
        shard_storage_key = _get_shard_storage_key(storage_key, shard_md,
                                                   storage_key_to_fqn)
        shard_to_storage_key[_get_shard_key(shard_md)] = shard_storage_key

    for shard in sharded_tensor.local_shards():
        tensor = shard.tensor.detach()
        shard_storage_key = shard_to_storage_key[_get_shard_key(
            shard.metadata)]

        wr = TensorWriteRequest(
            tensor=_trim(tensor),
            storage_key=shard_storage_key,
        )
        write_requests.append(wr)
    return write_requests, _compute_sharded_tensor_md(sharded_tensor,
                                                      shard_to_storage_key)
Example No. 12
def _all_gather_sharded_tensor(
        sharded_tensor: ShardedTensor,
        pg: Optional[dist.ProcessGroup] = None) -> torch.Tensor:
    if pg is None:
        pg = distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    shards = sharded_tensor.local_shards()
    dim_0_size = sharded_tensor.size()[0]  # type: ignore[index]
    tensor_numel = sharded_tensor.size().numel()  # type: ignore[union-attr]
    chunk_size = math.ceil(
        dim_0_size / world_size) * tensor_numel // dim_0_size
    cuda_device = torch.device("cuda", torch.cuda.current_device())
    if shards:
        local_tensor = shards[0].tensor.flatten()
        if not local_tensor.is_cuda:
            move_to_cpu = torch.ones(1, device=cuda_device)
            local_tensor = local_tensor.cuda()
        else:
            move_to_cpu = torch.zeros(1, device=cuda_device)
        num_padding = chunk_size - local_tensor.numel()
        if num_padding > 0:
            local_tensor = F.pad(local_tensor, [0, num_padding])
    else:
        local_tensor = torch.zeros(chunk_size,
                                   dtype=sharded_tensor.dtype,
                                   device=cuda_device)
        move_to_cpu = torch.zeros(1, device=cuda_device)

    tensor = torch.empty(
        chunk_size * world_size,
        dtype=local_tensor.dtype,
        device=cuda_device,
    )
    dist._all_gather_base(tensor, local_tensor, group=pg)

    tensor = tensor.narrow(0, 0, tensor_numel).reshape(sharded_tensor.size())
    return tensor
Example No. 13
def _init_sharded_tensor_from_local_result(
    sharded_tensor,
    local_result,
    tensor_shard_dim,
    result_shard_dim,
    world_size,
    pg,
):
    """
    Given a sharded tensor and a local_result from an op performed on top of
    it, we want to create a new sharded tensor from the local_result so that
    the next op can be performed on the basis of the new sharded tensor. This
    can be seen as the last step of the first phase of Megatron-LM style
    model (tensor) parallelism.

    Args:
        sharded_tensor: Sharded tensor which the op was performed on.
        local_result: A tensor which is from the op performed on the local_shard of
            the sharded_tensor.
        tensor_shard_dim: Dim which the tensor is sharded on.
        result_shard_dim: Dim which the new sharded tensor will be sharded on.
        world_size: number of ranks.
        pg (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` object which is filled with local intermediate results.
    """
    sharded_weight_metadata = copy.deepcopy(
        sharded_tensor.local_shards()[0].metadata)
    current_offsets = [0] * local_result.dim()
    current_offsets[result_shard_dim] = sharded_weight_metadata.shard_offsets[
        tensor_shard_dim]
    global_size = list(local_result.size())
    global_size[result_shard_dim] = sharded_tensor.size(tensor_shard_dim)
    local_shard_metadata = ShardMetadata(
        shard_offsets=current_offsets,
        shard_sizes=list(local_result.size()),
        placement=sharded_weight_metadata.placement,
    )
    local_shards = [Shard(local_result, local_shard_metadata)]
    new_st = ShardedTensor._init_from_local_shards(local_shards,
                                                   tuple(global_size),
                                                   process_group=pg)

    # Manually set sharding_spec
    new_st._sharding_spec = copy.deepcopy(sharded_tensor._sharding_spec)
    new_st._sharding_spec.dim = result_shard_dim
    return new_st
Example No. 14
def _compute_sharded_tensor_md(
        tensor: ShardedTensor,
        shard_to_storage_key: Dict[str, str]) -> ShardedTensorStorageMetadata:
    smd = []
    for shard_md in tensor.metadata().shards_metadata:
        shard_storage_key = shard_to_storage_key[_get_shard_key(shard_md)]

        shard_size = 1
        for d in shard_md.shard_sizes:
            shard_size *= d

        # not particularly great
        storage_size = shard_size * _get_sharded_tensor_element_size(tensor)

        one_smd = ShardStorageMetadata(
            shard_metadata=shard_md,
            storage_key=shard_storage_key,
        )
        smd.append(one_smd)

    return ShardedTensorStorageMetadata(
        tensor_metadata=tensor.metadata(),
        storage_metadata=smd,
    )
Example No. 15
def tensor_deepcopy(types, args=(), kwargs=None, pg=None):
    # NOTE: we directly implement the deepcopy magic method
    # instead of using the default tensor.__deepcopy__
    # and implementing clone(). This is because the default
    # tensor deepcopy copies every attribute, but the
    # process_group in ShardedTensor cannot be deep copied.
    self_st = args[0]
    # Validate types
    if not isinstance(self_st, ShardedTensor):
        raise TypeError("input needs to be a ShardedTensor")

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=copy.deepcopy(self_st.local_shards()),
        sharded_tensor_metadata=copy.deepcopy(self_st.metadata()),
        process_group=self_st._process_group,
        init_rrefs=self_st._init_rrefs)
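Usage is the ordinary copy protocol. Assuming this handler is the one registered for deepcopy dispatch on ShardedTensor (the registration mechanism is outside this snippet) and that st is an existing ShardedTensor on this rank, a sketch:

import copy

st_copy = copy.deepcopy(st)

# The shards and metadata are new objects, but the non-copyable process group
# is shared with the original.
assert st_copy.local_shards()[0].tensor is not st.local_shards()[0].tensor
assert st_copy._process_group is st._process_group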
Example No. 16
 def elementwise_op(types, args=(), kwargs=None, pg=None):
     """
     Handles ``__torch_function__`` dispatch for the elementwise op such
     as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
     This method computes on either a normal tensor or a sharded tensor.
     """
     input = args[0]
     # Validate types
     if not isinstance(input, ShardedTensor):
         raise TypeError("input needs to be a ShardedTensor")
     local_shards_new = []
     for local_shard in input.local_shards():
         local_shards_new.append(
             Shard(op(local_shard.tensor), local_shard.metadata))
     return ShardedTensor._init_from_local_shards_and_global_metadata(
         local_shards_new, input.metadata(), process_group=pg)
Example No. 17
    def sharded_chunk(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the chunk op.
        If we chunk along a non-sharding dim, we directly chunk the local
        tensor and create a list of ShardedTensors based on the results.

        Warning: Chunking along the sharding dim is not supported.

        Args: same as ``torch.chunk``.

        Return:
            List[ShardedTensor]: Chunk results as a list of ShardedTensor.
        """
        st = args[0]
        chunk_num = args[1]
        dim = kwargs.get("dim")
        dim = dim if dim else 0

        # Validate types
        if not isinstance(st, ShardedTensor):
            raise TypeError(
                f"torch function '{op.__name__}', with args: {args} and "
                f"kwargs: {kwargs} are called for non ShardedTensor!")
        spec = st.sharding_spec()
        if not isinstance(spec, ChunkShardingSpec):
            raise NotImplementedError(
                "Only ChunkShardingSpec is supported for chunk.")
        if spec.dim == dim or st.dim() + spec.dim == dim or st.dim(
        ) + dim == spec.dim:  # type: ignore[operator]
            raise NotImplementedError(
                "Chunk by sharding dim is not supported.")

        local_tensor = st.local_tensor()
        st_size = st.size()
        dim = dim if dim > 0 else st.dim() + dim
        results = []
        for chunk_tensor in local_tensor.chunk(chunk_num, dim=dim):
            new_st_size = (*st_size[:dim], chunk_tensor.size(dim),
                           *st_size[dim + 1:])  # type: ignore[index]
            results.append(
                ShardedTensor._init_from_local_tensor(
                    chunk_tensor.contiguous(),
                    st.sharding_spec(),
                    new_st_size,
                    process_group=pg,
                ))
        return results
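A hedged usage sketch: for a tensor sharded along dim 0, chunking along dim 1 only touches the local shards, while chunking along dim 0 raises NotImplementedError. The setup (2-rank process group, one GPU per rank, sharded_tensor.rand) is assumed as in the earlier sketch; note that dim must be passed as a keyword for this handler to pick it up.

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])
st = sharded_tensor.rand(spec, 8, 12)         # sharded along dim 0

parts = torch.chunk(st, 4, dim=1)             # chunk along the non-sharding dim
assert len(parts) == 4
assert parts[0].size() == torch.Size([8, 3])  # each part stays sharded along dim 0

# torch.chunk(st, 2, dim=0) would hit the "Chunk by sharding dim" error above.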
Example No. 18
    def shard(self,
              tensor: torch.Tensor,
              src_rank: int = 0,
              process_group=None) -> "ShardedTensor":
        # relative imports to avoid circular dependency
        from torch.distributed._shard.sharded_tensor import (ShardedTensor)
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned())
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []

        current_rank = dist.get_rank(process_group)
        # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
        dist.broadcast(tensor, src=src_rank, group=process_group)

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                process_group, shard_meta.placement)
            if rank == current_rank:
                shard_offsets = shard_meta.shard_offsets
                shard_sizes = shard_meta.shard_sizes
                local_tensor = tensor
                for idx, (offset,
                          size) in enumerate(zip(shard_offsets, shard_sizes)):
                    if size < tensor.size(idx):
                        # Reshape to get shard for this rank and we don't want autograd
                        # recording here for the narrow op and 'local_shard' should be a
                        # leaf variable in the autograd graph.
                        local_tensor = local_tensor.narrow(
                            idx, shard_offsets[idx],
                            shard_sizes[idx]).clone().detach().contiguous()
                # Sync requires_grad to local_shard.
                local_tensor.requires_grad = tensor.requires_grad
                local_shards.append(
                    Shard(tensor=local_tensor, metadata=shard_meta))

        st = ShardedTensor._init_from_local_shards(local_shards,
                                                   tensor.size(),
                                                   process_group=process_group)
        # Manually set sharding_spec
        st._sharding_spec = self

        return st
Example No. 19
def sharded_softmax(types, args=(), kwargs=None):
    input = args[0]
    pg = input._process_group
    dim = kwargs['dim']
    sharding_dim = input.sharding_spec().dim
    ndims = len(input.size())
    if dim == sharding_dim or dim + ndims == sharding_dim or sharding_dim + ndims == dim:
        exp = torch.exp(input.local_tensor())
        exp_sum = exp.sum(dim=dim).unsqueeze(dim=dim)
        exp_sum = torch.distributed.nn.functional.all_reduce(exp_sum, group=pg)
        smax = torch.div(exp, exp_sum)
    else:
        smax = torch.nn.functional.softmax(input.local_tensor(), dim=dim)
    return ShardedTensor._init_from_local_tensor(smax,
                                                 input.sharding_spec(),
                                                 input.size(),
                                                 process_group=pg)
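The else branch is a purely local softmax; the interesting case is softmax along the sharding dim, where the normalizer has to be summed across ranks. A single-process check of that identity (the split below merely stands in for two ranks' local shards; no process group is involved):

import torch

x = torch.randn(4, 6)
dim = 1
shard_a, shard_b = x.split([3, 3], dim=dim)   # pretend these live on two ranks

exp_a, exp_b = shard_a.exp(), shard_b.exp()
# The all_reduce in sharded_softmax supplies this global normalizer to every rank.
global_sum = (exp_a.sum(dim=dim) + exp_b.sum(dim=dim)).unsqueeze(dim)

reassembled = torch.cat([exp_a / global_sum, exp_b / global_sum], dim=dim)
assert torch.allclose(reassembled, torch.softmax(x, dim=dim))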
Example No. 20
def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
    if index.offset is None:
        raise ValueError(
            f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided"
        )

    shards = tensor.local_shards()
    # index fast path
    if index.index is not None:
        if len(shards) > index.index and torch.Size(
                shards[index.index].metadata.shard_offsets) == index.offset:
            return shards[index.index]

    for shard in shards:
        if torch.Size(shard.metadata.shard_offsets) == index.offset:
            return shard
    raise ValueError(
        f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
Example No. 21
def _prepare_sharded_tensor_read(
        metadata: ShardedTensorStorageMetadata,
        sharded_tensor_out: ShardedTensor) -> List[TensorReadRequest]:
    """
    Prepare sharded tensor read.

    Args:
        metadata: Metadata describing the persisted sharded tensor. Normally,
                  this is generated by :func:`_prepare_sharded_tensor_write`.
        sharded_tensor_out: The destination sharded tensor.

    Returns:
        A list of :class:`TensorReadRequest`. When fulfilled,
        ``sharded_tensor_out``'s local shards are loaded from the persisted
        sharded tensor.
    """
    return _prepare_generic_tensor_read(metadata.storage_metadata,
                                        sharded_tensor_out.local_shards())
Example No. 22
 def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
     st = args[0]
     st_metadata = st.metadata()
     local_shards = st.local_shards()
     local_shards_new = []
     if customized_func:
         local_shards_new, st_metadata = customized_func(args, kwargs, pg)
     else:
         for local_shard in local_shards:
             args = (local_shard.tensor, *args[1:])
             local_shards_new.append(
                 Shard(op(*args, **kwargs), local_shard.metadata)
             )
     return ShardedTensor._init_from_local_shards_and_global_metadata(
         local_shards_new,
         st_metadata,
         process_group=pg,
         init_rrefs=st._init_rrefs,
         sharding_spec=st.sharding_spec()
     )
Example No. 23
    def elementwise_op(types, args=(), kwargs=None, pg=None):
        """
        Handles ``__torch_function__`` dispatch for the elementwise op such
        as ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``.
        This method computes on either a normal tensor or a sharded tensor.
        """
        input = args[0]
        # Validate types
        if not isinstance(input, ShardedTensor):
            raise TypeError("input needs to be a ShardedTensor")
        local_shards_new = []
        for local_shard in input.local_shards():
            local_shards_new.append(
                Shard(op(local_shard.tensor), local_shard.metadata))
        # TODO: After a new API for sharded tensor creation, we need to replace this.
        # https://github.com/pytorch/pytorch/issues/72092
        new_st = ShardedTensor._init_from_local_shards(local_shards_new,
                                                       input.size(),
                                                       process_group=pg)

        # Manually set sharding_spec
        new_st._sharding_spec = copy.deepcopy(input._sharding_spec)
        return new_st
Example No. 24
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
    shards_metadata = []
    local_shards = []
    for idx in range(0, world_size * shards_per_rank):
        shard_rank = idx // shards_per_rank
        shard_md = ShardMetadata(shard_offsets=[idx * shard_size],
                                 shard_sizes=[shard_size],
                                 placement=f"rank:{shard_rank}/cpu")
        shards_metadata.append(shard_md)
        if shard_rank == rank:
            shard = Shard.from_tensor_and_offsets(
                torch.rand(*shard_md.shard_sizes),
                shard_offsets=shard_md.shard_offsets,
                rank=rank)
            local_shards.append(shard)

    sharded_tensor_md = ShardedTensorMetadata(
        shards_metadata=shards_metadata,
        size=torch.Size([shard_size * len(shards_metadata)]),
        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)))

    return ShardedTensor._init_from_local_shards_and_global_metadata(
        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md)
Example No. 25
 def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
     res = torch.zeros(tensor.shape,
                       device="cpu") if dist.get_rank() == 0 else None
     tensor.gather(out=res)
     return res
Example No. 26
def _shard_tensor(tensor: torch.Tensor,
                  sharding_spec: ShardingSpec,
                  src_rank=0,
                  process_group=None):
    """
    Given a :class:`torch.Tensor`, it shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank which would be
    used as the ground truth of the data which would be scattered as shards
    across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): Tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
            describing how to shard the Tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
    """
    if not isinstance(sharding_spec, ChunkShardingSpec):
        raise NotImplementedError('Only ChunkShardingSpec is supported.')
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group(
    )
    world_size = dist.get_world_size(pg)
    rank = dist.get_rank(pg)

    # Validate src_rank and sharding_spec are same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    # Rearrange chunks according to placement.
    local_metadata = None
    current_offsets = [0] * len(tensor.size())
    shards_metadata = []
    sharding_dim_size = tensor.size(
        sharding_spec.dim)  # type: ignore[arg-type]
    split_size = get_split_size(sharding_dim_size, world_size)
    tensor_sizes = list(tensor.size())
    for idx, placement in enumerate(sharding_spec.placements):
        chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size,
                                                idx)
        shard_size = copy.deepcopy(tensor_sizes)
        shard_size[sharding_spec.dim] = chunked_dim_size  # type: ignore[index]

        shard_metadata = ShardMetadata(
            shard_offsets=copy.deepcopy(current_offsets),
            shard_sizes=shard_size,
            placement=placement,
        )
        shards_metadata.append(shard_metadata)

        if rank == placement.rank():  # type: ignore[union-attr]
            local_metadata = shard_metadata

        current_offsets[
            sharding_spec.dim] += chunked_dim_size  # type: ignore[index]

    # Scatter the shards (use broadcast since NCCL doesn't support scatter, this is very inefficient).
    dist.broadcast(tensor, src=src_rank, group=pg)

    # Reshape to get shard for this rank and we don't want autograd
    # recording here for the narrow op and 'local_shard' should be a
    # leaf variable in the autograd graph.
    local_shard = tensor.narrow(
        sharding_spec.dim,  # type: ignore[arg-type]
        local_metadata.shard_offsets[
            sharding_spec.dim],  # type: ignore[union-attr, arg-type, index]
        local_metadata.shard_sizes[
            sharding_spec.dim],  # type: ignore[union-attr, index]
    ).clone().detach().contiguous()

    # Sync requires_grad to local_shard.
    local_shard.requires_grad = tensor.requires_grad

    # Create ShardedTensor based on local shards.
    local_shards = [
        Shard(
            tensor=local_shard,
            metadata=local_metadata,  # type: ignore[arg-type]
        )
    ]

    return ShardedTensor._init_from_local_shards(local_shards,
                                                 tensor.size(),
                                                 process_group=pg)
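A hedged end-to-end sketch of calling _shard_tensor: two ranks, a ChunkShardingSpec along dim 0, rank 0 as the ground truth. The import path for _shard_tensor and the launcher environment (RANK/WORLD_SIZE set, one GPU per rank) are assumptions.

import torch
import torch.distributed as dist
from torch.distributed._shard import _shard_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

dist.init_process_group("nccl")     # assumes the usual env:// variables are set
rank = dist.get_rank()
torch.cuda.set_device(rank)

spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:0", "rank:1/cuda:1"])

# Rank 0's data is broadcast and then narrowed per the spec; each rank ends up
# holding a [4, 16] local shard of the [8, 16] global tensor.
tensor = torch.randn(8, 16, device="cuda")
st = _shard_tensor(tensor, spec, src_rank=0)
print(st.local_shards()[0].tensor.shape)   # torch.Size([4, 16])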
Example No. 27
    def shard(self,
              tensor: torch.Tensor,
              src_rank: int = 0,
              process_group=None) -> "ShardedTensor":
        # relative imports to avoid circular dependency
        from torch.distributed._shard.sharded_tensor import (ShardedTensor)
        tensor_properties = sharded_tensor_meta.TensorProperties(
            dtype=tensor.dtype,
            layout=tensor.layout,
            requires_grad=tensor.requires_grad,
            memory_format=torch.contiguous_format,
            pin_memory=tensor.is_pinned())
        current_rank = dist.get_rank(process_group)
        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
        local_shards = []
        local_tensor = None
        local_metadata = None
        tensors_to_scatter = [None] * dist.get_world_size(process_group)

        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
        chunks = len(self.placements)
        split_size = get_split_size(sharding_dim_size, chunks)
        scatter_shape = list(tensor.size())
        scatter_shape[self.dim] = split_size  # type: ignore[index]

        for shard_meta in tensor_meta.shards_metadata:
            rank, device = _parse_and_validate_remote_device(
                process_group, shard_meta.placement)
            if current_rank == src_rank:
                # Reshape to get shard for this rank and we don't want autograd
                # recording here for the narrow op and 'local_shard' should be a
                # leaf variable in the autograd graph.
                narrowed_tensor = narrow_tensor(tensor, shard_meta)
                if shard_meta.shard_sizes[
                        self.dim] < split_size:  # type: ignore[index]
                    # for the last shard, which might be smaller than the other
                    # shards, resize the narrowed tensor to the same size and use
                    # it for the scatter collective, as dist.scatter requires
                    # same-size inputs on every rank
                    tensor_to_scatter = narrowed_tensor.detach().clone(
                    ).resize_(scatter_shape)
                else:
                    tensor_to_scatter = narrowed_tensor.detach().clone(
                    ).contiguous()

                tensors_to_scatter[rank] = tensor_to_scatter

            if current_rank == rank:
                local_tensor = torch.empty(scatter_shape,
                                           dtype=tensor.dtype,
                                           layout=tensor.layout,
                                           device=device)
                local_metadata = shard_meta

        # each rank should have local_tensor and local_metadata initialized if we build
        # the metadata list in a correct way.
        assert local_tensor is not None
        assert local_metadata is not None

        # Scatter the shards to all ranks in the pg
        dist.scatter(local_tensor,
                     scatter_list=tensors_to_scatter
                     if current_rank == src_rank else None,
                     src=src_rank,
                     group=process_group)

        if list(local_tensor.size()) != local_metadata.shard_sizes:
            # detach again after receiving to ensure local shards remain a leaf node
            local_tensor = local_tensor.resize_(
                local_metadata.shard_sizes).detach()

        # Sync requires_grad to local_shard.
        local_tensor.requires_grad = tensor.requires_grad

        local_shards.append(Shard(tensor=local_tensor,
                                  metadata=local_metadata))

        st = ShardedTensor._init_from_local_shards_and_global_metadata(
            local_shards, tensor_meta, process_group=process_group)

        # Manually set sharding_spec
        st._sharding_spec = self

        return st
Example No. 28
    def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> "ShardedTensor":
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert them to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with local aggregated result.
        """
        from torch.distributed._shard.sharded_tensor.api import ShardedTensor

        if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
        if self._local_shard.is_complex():
            raise NotImplementedError("Only real partial tensor supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size()
        local_shard = self._local_shard
        # Add padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            padding = [0] * (local_shard.dim() * 2)
            padding[-1] = self._process_group.size() - chunk_mode_res
            local_shard = torch.nn.functional.pad(
                local_shard,
                tuple(padding),
                "constant",
                0,
            )
        current_rank = dist.get_rank(self._process_group)  # type: ignore[attr-defined]
        rank_idx = None
        rearrange_local_shards = False
        indices = [0] * self._process_group.size()
        for idx, placement in enumerate(resharding_spec.placements):  # type: ignore[attr-defined]
            if placement.rank() == current_rank:  # type: ignore[index, union-attr]
                rank_idx = idx  # type: ignore[attr-defined]
            if placement.rank() != idx:  # type: ignore[index, union-attr]
                rearrange_local_shards = True
            indices[placement.rank()] = idx  # type: ignore[index, union-attr]

        local_shards = local_shard.chunk(self._process_group.size(), dim=sharding_dim)
        if rearrange_local_shards:
            # Need to re-arrange original shard_dim of output_tensor_list.
            local_shards = [local_shards[idx] for idx in indices]  # type: ignore[call-overload]
        local_result = reduce_scatter(
            torch.empty_like(local_shards[0]),
            list(local_shards),
            op=self._reduce_op,
            group=self._process_group,
        )

        sharded_tensor_size = self._local_shard.size()
        # Remove padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            uneven_local_shards = self._local_shard.chunk(
                self._process_group.size(), dim=sharding_dim
            )
            expected_size = uneven_local_shards[rank_idx].size()  # type: ignore[index]
            if local_result.size() != expected_size:
                local_result = local_result.narrow(
                    sharding_dim,
                    0,
                    expected_size[sharding_dim],
                )
        return ShardedTensor._init_from_local_tensor(
            local_result,
            resharding_spec,
            sharded_tensor_size,
            process_group=self._process_group,
        )
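The padding and unpadding around the reduce-scatter is easiest to follow with numbers. Suppose the partial tensor has 10 elements along the resharding dim and there are 4 ranks: the local shard is padded to 12, reduce-scattered in chunks of 3, and the rank that owns the last (uneven) chunk narrows its result back to 1 element. A plain-Python illustration of that bookkeeping (the numbers are arbitrary):

size_along_dim, world_size = 10, 4

remainder = size_along_dim % world_size      # 2 -> not evenly divisible
padding = world_size - remainder             # pad by 2 -> padded length 12
chunk_len = (size_along_dim + padding) // world_size   # 3 elements per rank

# Sizes of the unpadded chunks each rank should finally keep, matching
# torch.Tensor.chunk on the original (unpadded) dim:
uneven_chunks = [min(chunk_len, max(0, size_along_dim - r * chunk_len))
                 for r in range(world_size)]
print(chunk_len, uneven_chunks)              # 3 [3, 3, 3, 1]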
Example No. 29
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor only has one dimension, we add one more dimension
    for resharding. We then need to squeeze manually to reduce the dimension
    later on.

    For example, if we have:
    input: size[15]
    weight: size[15, 16]
    world_size: 4

    In each rank, we will have 4 * [4] tensors. We then stack them into a [4, 4]
    tensor and generate a sharded tensor sharded by dim 1.

    For all other situations, we simply concatenate the local tensors. No more
    actions are needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: the rank of the current CUDA process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of the linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object which is filled with local intermediate results.
    """
    # allgather the inputs first.
    out_size = list(input.size())
    out_size[0] = input.size(0) * dist.get_world_size(pg)
    output = torch.empty(out_size, device=input.device)
    output = _all_gather_base(output, input, group=pg)

    # Adjust bias and perform local matmul.
    (start_pos,
     chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                             weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(world_size, start_pos, chunk_size,
                                         weight, pg, bias)

    if output.dim() == 1:
        output = output.view(dist.get_world_size(pg), -1)

    if output.dim() <= 2:
        # Use fused version if possible.
        result = torch.addmm(local_bias, output, local_shard_t)
    else:
        result = output.matmul(local_shard_t) + local_bias

    # Build ShardedTensor as result.
    st_size = list(result.size())
    st_size[-1] = weight.size(0)
    new_sharding_spec = ChunkShardingSpec(
        dim=-1, placements=weight.sharding_spec().placements)
    return ShardedTensor._init_from_local_tensor(
        result,
        new_sharding_spec,
        *st_size,  # type: ignore[arg-type]
        process_group=pg,
    )
Example No. 30
def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
    """
    Handles ``__torch_function__`` dispatch for the binary math ops
    such as `torch.add`, `torch.mul`, `torch.div`, etc.
    This method computes on ShardedTensor, or ShardedTensor op ReplicatedTensor
    """
    if len(args) != 2:
        raise ValueError(
            "Only support binary math op on ShardedTensor for now!")
    lhs = args[0]
    rhs = args[1]
    # Validate types
    if isinstance(lhs, ReplicatedTensor):
        assert isinstance(rhs, ShardedTensor)
        st_size = rhs.size()
        st_meta = rhs.local_shards()[0].metadata
        if st_size != lhs.size():
            # try to broadcast replicated tensor
            lhs = lhs.expand(st_size)

        replica_part = narrow_tensor(lhs, st_meta)
        res = op(replica_part, rhs.local_tensor())

        return ShardedTensor._init_from_local_tensor(
            res,
            rhs.sharding_spec(),
            rhs.size(),  # type: ignore[arg-type]
            process_group=pg)

    elif isinstance(rhs, ReplicatedTensor):
        assert isinstance(lhs, ShardedTensor)
        st_size = lhs.size()
        st_meta = lhs.local_shards()[0].metadata
        if st_size != rhs.size():
            # try to broadcast replicated tensor
            rhs = rhs.expand(st_size)

        replica_part = narrow_tensor(rhs, st_meta)
        res = op(lhs.local_tensor(), replica_part)
        return ShardedTensor._init_from_local_tensor(
            res,
            lhs.sharding_spec(),
            lhs.size(),  # type: ignore[arg-type]
            process_group=pg)

    elif isinstance(lhs, (int, float)):
        assert isinstance(rhs, ShardedTensor)
        res = op(lhs, rhs.local_tensor())
        return ShardedTensor._init_from_local_tensor(
            res,
            rhs.sharding_spec(),
            rhs.size(),  # type: ignore[arg-type]
            process_group=pg)

    elif isinstance(rhs, (int, float)):
        assert isinstance(lhs, ShardedTensor)
        res = op(lhs.local_tensor(), rhs)
        return ShardedTensor._init_from_local_tensor(
            res,
            lhs.sharding_spec(),
            lhs.size(),  # type: ignore[arg-type]
            process_group=pg)
    else:
        raise RuntimeError(
            f"torch function '{op.__name__}', with args: {args} and "
            f"kwargs: {kwargs} not supported yet for ShardedTensor!")