Example 1
    def reshard(self, resharding_spec: ShardingSpec) -> ShardedTensor:
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with the local aggregated result.
        """
        if not isinstance(resharding_spec, ChunkShardingSpec):
            raise NotImplementedError(
                "Only ChunkShardingSpec supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        if self.local_shard.size(
                sharding_dim) % self.process_group.size() != 0:
            raise ValueError(
                'World size needs to divide the length of the dimension.')
        if self.local_shard.is_complex():
            raise NotImplementedError(
                "Only real partial tensor supported for reshard.")

        local_shards = self.local_shard.chunk(self.process_group.size(),
                                              dim=sharding_dim)
        local_result = reduce_scatter(torch.empty_like(local_shards[0]),
                                      list(local_shards),
                                      op=self.reduce_op)

        sharded_tensor_size = self.local_shard.size()
        current_offsets = [0] * len(local_result.size())
        shards = []
        rank = self.process_group.rank()
        for idx, placement in enumerate(
                resharding_spec.placements):  # type: ignore[attr-defined]
            if rank == placement.rank():  # type: ignore[union-attr]
                local_metadata = ShardMetadata(
                    shard_offsets=current_offsets,
                    shard_sizes=list(local_result.size()),
                    placement=placement,
                )
                shards.append(Shard(local_result, local_metadata))
                break
            current_offsets[sharding_dim] += local_result.size(
                sharding_dim)  # type: ignore[index]

        st = ShardedTensor._init_from_local_shards(
            shards,
            tuple(sharded_tensor_size),
            process_group=self.process_group)
        st._sharding_spec = copy.deepcopy(resharding_spec)

        return st
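
A minimal usage sketch of this reshard path, assuming a multi-rank process group is already initialized and that partial_tensor is an existing _PartialTensor (how it gets constructed is an internal detail that varies across PyTorch versions); the placement strings follow the "rank:<rank>/<device>" form accepted by ChunkShardingSpec.

import torch.distributed as dist
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

def reshard_partial(partial_tensor):
    # Chunk along dim 0 with one shard per rank.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
    )
    # Internally this is a single reduce_scatter: the partial results are
    # combined with the tensor's reduce_op and each rank keeps only its own
    # chunk of the aggregate, returned as a ShardedTensor.
    return partial_tensor.reshard(spec)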
Example 2
    def reshard(self,
                resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with the local aggregated result.
        """
        if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
            raise NotImplementedError(
                "Only ChunkShardingSpec supported for reshard.")
        if self.local_shard.is_complex():
            raise NotImplementedError(
                "Only real partial tensor supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        chunk_mode_res = self.local_shard.size(
            sharding_dim) % self.process_group.size()
        # Add padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            padding = [0] * (len(self.local_shard.size()) * 2)
            padding[-1] = self.process_group.size() - chunk_mode_res
            self.local_shard = torch.nn.functional.pad(
                self.local_shard,
                tuple(padding),
                "constant",
                0,
            )

        local_shards = self.local_shard.chunk(self.process_group.size(),
                                              dim=sharding_dim)
        local_result = reduce_scatter(torch.empty_like(local_shards[0]),
                                      list(local_shards),
                                      op=self.reduce_op)

        sharded_tensor_size = self.local_shard.size()
        return ShardedTensor._init_from_local_tensor(
            local_result,
            resharding_spec,
            sharded_tensor_size,
            process_group=self.process_group,
        )
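
A small standalone illustration of the padding step above, with hypothetical sizes (world size 4, a (10, 8) local shard chunked along dim 0): F.pad lists padding sizes starting from the last dimension, so the final entry grows the end of dim 0 until its length is divisible by the world size.

import torch
import torch.nn.functional as F

world_size = 4
local_shard = torch.randn(10, 8)                 # 10 is not divisible by 4

remainder = local_shard.size(0) % world_size     # 10 % 4 == 2
if remainder != 0:
    padding = [0] * (local_shard.dim() * 2)
    padding[-1] = world_size - remainder         # append 2 rows of zeros to dim 0
    local_shard = F.pad(local_shard, tuple(padding), "constant", 0)

assert local_shard.size(0) == 12                 # now divisible by the world size
chunks = local_shard.chunk(world_size, dim=0)    # four equally sized (3, 8) chunks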
Example 3
    def reshard(self, resharding_spec: ShardingSpec) -> ShardedTensor:
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with the local aggregated result.
        """
        if not isinstance(resharding_spec, ChunkShardingSpec):
            raise NotImplementedError(
                "Only ChunkShardingSpec supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        if self.local_shard.size(
                sharding_dim) % self.process_group.size() != 0:
            raise ValueError(
                'World size needs to divide the length of the dimension.')
        if self.local_shard.is_complex():
            raise NotImplementedError(
                "Only real partial tensor supported for reshard.")

        local_shards = self.local_shard.chunk(self.process_group.size(),
                                              dim=sharding_dim)
        local_result = reduce_scatter(torch.empty_like(local_shards[0]),
                                      list(local_shards),
                                      op=self.reduce_op)

        sharded_tensor_size = self.local_shard.size()
        return ShardedTensor._init_from_local_tensor(
            local_result,
            resharding_spec,
            sharded_tensor_size,
            process_group=self.process_group,
        )
Example 4
def _handle_row_wise_sharding(input, world_size, weight, local_shard, max_norm,
                              norm_type, padding_idx, rank, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embedding. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding.)

    Args:
        input: list of IDs used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
        rank: rank of the current CUDA process.
        pg: process group.

    Returns: final result of lookup.
    """
    if not isinstance(input, ReplicatedTensor):
        # All-gather the inputs first when the input is not a ReplicatedTensor.
        gather_inp = _all_gather_base_input(input, pg)
    else:
        gather_inp = input

    # Mask the input according to sharding spec.
    lookup_input, padding_idx, padding_row = _handle_row_wise_mask(
        gather_inp, padding_idx, weight, world_size, rank)

    # When input is a large tensor, the value of weight is changed.
    # This is a workaround for now. GH issue: #81717
    if max_norm is not None:
        torch.nn.functional.embedding(
            torch.unique(lookup_input)[:-1],
            local_shard,
            padding_idx=padding_idx,
            max_norm=max_norm,
            norm_type=norm_type,
        )
        max_norm = None

    local_input_embeddings = torch.nn.functional.embedding(
        lookup_input,
        torch.cat([local_shard, padding_row]),
        padding_idx=padding_idx,
        max_norm=max_norm,
        norm_type=norm_type,
    )

    # TODO: Make the result a PartialTensor.
    if isinstance(input, ReplicatedTensor):
        return all_reduce(local_input_embeddings, group=pg)
    else:
        local_shards = local_input_embeddings.chunk(pg.size())
        return reduce_scatter(
            torch.empty_like(local_shards[0]),
            list(local_shards),
            group=pg,
        )
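
The masking-plus-zero-row trick above can be demonstrated in a single process with ordinary tensors: each simulated rank owns a contiguous block of embedding rows, redirects IDs it does not own to an extra zero row, and the element-wise sum of the per-rank lookups equals the full lookup (the distributed version obtains that sum via the reduce_scatter of chunked results shown above). Names and sizes here are purely illustrative.

import torch
import torch.nn.functional as F

num_embeddings, dim, world_size = 8, 4, 2
weight = torch.randn(num_embeddings, dim)
ids = torch.tensor([0, 3, 5, 7])

rows_per_rank = num_embeddings // world_size
partials = []
for rank in range(world_size):
    start = rank * rows_per_rank
    local_rows = weight[start:start + rows_per_rank]
    # IDs owned by this rank map to local row indices; everything else is sent
    # to the trailing zero row so it contributes nothing to the sum.
    owned = (ids >= start) & (ids < start + rows_per_rank)
    lookup = torch.where(owned, ids - start, torch.tensor(rows_per_rank))
    partials.append(F.embedding(lookup, torch.cat([local_rows, torch.zeros(1, dim)])))

# Summing the per-rank partial lookups reproduces the unsharded lookup.
assert torch.allclose(sum(partials), F.embedding(ids, weight))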
Example 5
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current CUDA process.
        local_shard_t: transposed row-wise sharded local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    local_result = reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
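
The chunk-split convention referenced in the rearrangement comment above (the first N - 1 chunks are maxed out and the last one takes the remainder) can be written out directly; this mirrors what get_split_size and get_chunked_dim_size compute, with the arithmetic inlined for a hypothetical length-13 dimension split across 4 ranks.

dim_size, world_size = 13, 4
split_size = -(-dim_size // world_size)        # ceil(13 / 4) == 4
chunk_sizes = [
    max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)
    for idx in range(world_size)
]
assert chunk_sizes == [4, 4, 4, 1]             # never [3, 3, 3, 4]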
Example 6
def _handle_row_wise_sharding(
    input,
    world_size,
    weight,
    local_shard,
    offsets,
    per_sample_weights,
    mode,
    max_norm,
    norm_type,
    padding_idx,
    rank,
    pg,
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embeddingBag. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding_bag.)

    Args:
        input: list of IDs used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        offsets: list of start positions of each bag for 1D input.
        per_sample_weights: weights for weighted sum mode.
        mode: aggregation method of each bag.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
            Note that the embedding vector at padding_idx is
            excluded from the reduction.
        rank: rank of the current CUDA process.
        pg: process group.

    Returns:
        gathered_output: final result of lookup and aggregation.
    """
    if not isinstance(input, ReplicatedTensor):
        if input.dim() > 1 and per_sample_weights is None:
            # All-gather the inputs first when the input is not a ReplicatedTensor.
            gather_inp = _all_gather_base_input(input, pg)
        else:
            (
                gathered_inputs,
                gathered_per_sample_weights,
                gathered_offsets,
            ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg)
            cat_dim = 0 if input.dim() != 1 else -1
            gather_inp = torch.cat(gathered_inputs, dim=cat_dim)
            if per_sample_weights is not None:
                per_sample_weights = torch.cat(gathered_per_sample_weights, dim=cat_dim)
            offset_add = 0 if input.dim() > 1 else input.size(0)
            if offsets is not None:
                offsets_list = torch.cat(
                    [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())],
                    dim=cat_dim,
                )
    else:
        gather_inp = input

    # Mask the input according to sharding spec.
    lookup_input, padding_local, padding_row = _handle_row_wise_mask(
        gather_inp, padding_idx, weight, world_size, rank
    )
    if mode == "max":
        padding_row[:] = -float("Inf")

    # When input is a large tensor, the value of weight is changed.
    # This is a workaround for now. GH issue: #81717.
    if max_norm is not None:
        torch.nn.functional.embedding_bag(
            torch.unique(lookup_input)[:-1],
            local_shard,
            offsets=torch.tensor([0], device=local_shard.device, dtype=torch.long),
            mode=mode,
            per_sample_weights=None,
            max_norm=max_norm,
            norm_type=norm_type,
            padding_idx=padding_local,
        )
        max_norm = None
    result = torch.nn.functional.embedding_bag(
        lookup_input,
        torch.cat([local_shard, padding_row]),
        offsets=offsets_list if offsets is not None else offsets,
        mode=mode if mode != "mean" else "sum",
        per_sample_weights=per_sample_weights,
        max_norm=max_norm,
        norm_type=norm_type,
        padding_idx=padding_local,
    )

    op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX
    # TODO: Make the result a PartialTensor and move the logic below there.
    if isinstance(input, ReplicatedTensor):
        result = all_reduce(result, op=op, group=pg)
    else:
        local_shards = result.chunk(pg.size())
        result = reduce_scatter(
            torch.empty_like(local_shards[0]),
            list(local_shards),
            op=op,
            group=pg,
        )

    # For mean mode, we cannot do the division until the very end because the
    # sum of means is not equal to the mean of sums (the divisors are different).
    if mode == "mean":
        if input.dim() > 1:
            padding_idx = padding_idx if padding_idx is not None else -1
            split_sizes = torch.sum(
                torch.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype
            )
        else:
            split_sizes = torch.cat(
                (
                    offsets[1 : offsets.size(0)] - offsets[0:-1],
                    (input.size(0) - offsets[-1]).unsqueeze(0),
                ),
                dim=-1,
            )
        return torch.div(result, split_sizes.unsqueeze(1))

    # Return the appropriate local result.
    return result
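
A tiny worked example of the divisor argument behind deferring the mean-mode division: averaging per-shard partial sums with per-shard divisors gives the wrong answer because the sum of means is not the mean of sums. This is a toy illustration of the principle, not the module's actual data flow.

import torch

bag = torch.tensor([1.0, 2.0, 3.0])        # one bag whose IDs land on two shards
shard_sums = [torch.tensor(1.0 + 2.0),     # shard 0 owns the first two rows
              torch.tensor(3.0)]           # shard 1 owns the last row

correct = sum(shard_sums) / bag.numel()                  # (3 + 3) / 3 == 2.0
premature = (shard_sums[0] / 2 + shard_sums[1] / 1) / 2  # 2.25: wrong divisors
assert correct.item() == 2.0 and premature.item() != correct.item()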
Example 7
    def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor:
        """
        The reshard happens in two steps logically:

        1. Aggregate all the shards of the partial tensor.
        2. Shard this tensor according to the provided spec.

        In reality, for the sake of performance, we consolidate all partial tensors
        across multiple ranks and convert to a sharded tensor in one step.

        Args:
            resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
                The specification describing how we reshard the aggregated local result.

        Returns:
            A :class:`ShardedTensor` filled with the local aggregated result.
        """
        if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
            raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
        if self.local_shard.is_complex():
            raise NotImplementedError("Only real partial tensor supported for reshard.")
        sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
        chunk_mode_res = self.local_shard.size(sharding_dim) % self.process_group.size()
        local_shard = self.local_shard
        # Add padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            padding = [0] * (self.local_shard.dim() * 2)
            padding[-1] = self.process_group.size() - chunk_mode_res
            local_shard = torch.nn.functional.pad(
                self.local_shard,
                tuple(padding),
                "constant",
                0,
            )
        current_rank = dist.get_rank(self.process_group)  # type: ignore[attr-defined]
        rank_idx = None
        rearrange_local_shards = False
        indices = [0] * self.process_group.size()
        for idx, placement in enumerate(resharding_spec.placements):  # type: ignore[attr-defined]
            if placement.rank() == current_rank:  # type: ignore[index, union-attr]
                rank_idx = idx  # type: ignore[attr-defined]
            if placement.rank() != idx:  # type: ignore[index, union-attr]
                rearrange_local_shards = True
            indices[placement.rank()] = idx  # type: ignore[index, union-attr]

        local_shards = local_shard.chunk(self.process_group.size(), dim=sharding_dim)
        if rearrange_local_shards:
            # Need to re-arrange original shard_dim of output_tensor_list.
            local_shards = [local_shards[idx] for idx in indices]  # type: ignore[call-overload]
        local_result = reduce_scatter(
            torch.empty_like(local_shards[0]), list(local_shards), op=self.reduce_op
        )

        sharded_tensor_size = self.local_shard.size()
        # Remove padding when the size is not divisible by the world size.
        if chunk_mode_res != 0:
            uneven_local_shards = self.local_shard.chunk(
                self.process_group.size(), dim=sharding_dim
            )
            expected_size = uneven_local_shards[rank_idx].size()
            if local_result.size() != expected_size:
                local_result = local_result.narrow(
                    sharding_dim,
                    0,
                    expected_size[sharding_dim],
                )
        return ShardedTensor._init_from_local_tensor(
            local_result,
            resharding_spec,
            sharded_tensor_size,
            process_group=self.process_group,
        )
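
The shard-reordering loop above can be illustrated with plain lists and hypothetical placements: when the spec assigns chunk 0 to rank 2, chunk 1 to rank 0, and chunk 2 to rank 1, the permutation puts the chunk destined for rank r at list position r, which is the input ordering reduce_scatter expects.

placement_ranks = [2, 0, 1]            # placements[idx].rank() for each chunk idx
chunks = ["chunk0", "chunk1", "chunk2"]

indices = [0] * len(placement_ranks)
for idx, r in enumerate(placement_ranks):
    indices[r] = idx                   # chunk idx is destined for rank r

reordered = [chunks[i] for i in indices]
assert reordered == ["chunk1", "chunk2", "chunk0"]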