Example #1
def _result_distribute_with_col_rearrange(results, input, world_size, weight,
                                          pg):
    """
    For column-wise sharding of the weight, we need to distribute
    the results back to each rank. This function handles that distribution.
    Note that if the index in the Sharding Spec is not equal to
    the rank number, we need to rearrange the result based on the
    order given by the Sharding Spec (placement).

    Args:
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor that the op is applied to.
        world_size: number of ranks.
        weight: sharded weight tensor.
        pg: process group.

    Returns: column-rearranged result.
    """
    # Process results and outputs for all2all.
    sharding_dim = weight._sharding_spec.dim
    sharding_dim_size = weight.size(sharding_dim)
    dims = list(results[0].size())
    dims[0] = sharding_dim_size
    combined_results = torch.cat(results)
    output = torch.empty(*dims,
                         device=combined_results.device,
                         dtype=combined_results.dtype)

    # Compute output splits
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for idx, placement in enumerate(weight._sharding_spec.placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, idx)

    # distribute the outputs using all2all.
    output = all_to_all_single(output,
                               combined_results,
                               output_split_sizes=output_split_sizes,
                               group=pg)

    # Check if we need to rearrange columns appropriately for output.
    rearrange_columns = any([
        idx != placement.rank()
        for idx, placement in enumerate(weight._sharding_spec.placements)
    ])
    if not rearrange_columns:
        return output

    indices = []
    for placement in weight._sharding_spec.placements:
        dim_size = output_split_sizes[placement.rank()]
        start = sum([
            split_size if i < placement.rank() else 0
            for i, split_size in enumerate(output_split_sizes)
        ])
        indices += list(range(start, start + dim_size))

    return output.index_select(0, torch.tensor(indices, device=output.device))
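
For intuition, the split-size and rearrangement bookkeeping above can be reproduced on plain Python lists. The snippet below is a hypothetical sketch: _split_size and _chunked_dim_size are stand-ins for get_split_size and get_chunked_dim_size, written under the assumption that sharding follows torch.chunk semantics (ceil-division chunks, remainder in the last chunk), and the placement order [2, 0, 3, 1] is made up for illustration.

def _split_size(dim_size, chunks):
    # Assumed behavior of get_split_size: ceil division.
    return -(-dim_size // chunks)

def _chunked_dim_size(dim_size, split_size, idx):
    # Assumed behavior of get_chunked_dim_size: size of the idx-th chunk.
    return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0)

sharding_dim_size, world_size = 10, 4
placement_ranks = [2, 0, 3, 1]      # Sharding Spec order differs from rank order.
split_size = _split_size(sharding_dim_size, world_size)                  # 3

output_split_sizes = [0] * world_size
for idx, rank in enumerate(placement_ranks):
    output_split_sizes[rank] = _chunked_dim_size(sharding_dim_size, split_size, idx)
print(output_split_sizes)           # [3, 1, 3, 3]: rank 1 holds the last (short) chunk.

# Rearrangement indices, mirroring the index_select at the end of the function.
indices = []
for rank in placement_ranks:
    dim_size = output_split_sizes[rank]
    start = sum(s for i, s in enumerate(output_split_sizes) if i < rank)
    indices += list(range(start, start + dim_size))
print(indices)                      # [4, 5, 6, 0, 1, 2, 7, 8, 9, 3]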
Example #2
def reshuffle_local_shard(
    local_shard: torch.Tensor,
    st_size: torch.Size,
    sharding_spec: ShardingSpec,
    resharding_spec: ShardingSpec,
    pg: ProcessGroup,
) -> Tuple[List[Shard], List[ShardMetadata]]:
    """
    Reshuffle the local shard directly when the reshard dim is the same as the original
    sharding dim. Logically we do this in two steps:
    1. Collect all shards based on the original sharding spec.
    2. Reshard the tensor based on the given resharding spec.

    In reality, we consolidate the two steps into one by sending the local tensor to
    the new shard directly based on the resharding spec.

    Args:
        local_shard (Tensor): Local shard stored on the current rank.
        st_size (torch.Size): The size of the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor is sharded originally.
        resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how the tensor will be resharded.
        pg (ProcessGroup): The process group to aggregate on.

    Returns:
        A Tuple of the following:
            A List[`Shard`] which contains the local tensor and its metadata.
            A List[`ShardMetadata`] which contains the metadata for the shard, including
                offsets, lengths and device placement.
    """
    current_rank = dist.get_rank(pg)
    world_size = dist.get_world_size(pg)
    # Build shards_metadata first.
    shards_metadata, ranks = build_reshard_metadata(st_size, resharding_spec,
                                                    world_size)
    # Get input split size for all2all.
    reshard_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
    split_size = get_split_size(st_size[reshard_dim], world_size)
    input_split_sizes = [0] * world_size
    idx = get_idx_from_placements(sharding_spec.placements,
                                  current_rank)  # type: ignore[attr-defined]
    new_rank = resharding_spec.placements[idx].rank(
    )  # type: ignore[union-attr, attr-defined]
    input_split_sizes[new_rank] = local_shard.size(reshard_dim)
    # Get output split size for all2all.
    output_split_sizes = [0] * world_size
    new_idx = ranks.index(current_rank)
    sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size,
                                            new_idx)
    output_split_sizes[new_rank] = sharded_dim_size
    # Get gathered_input for all2all.
    local_shard = local_shard.transpose(0, reshard_dim).contiguous()
    gathered_input_size = list(local_shard.size())
    gathered_input_size[0] = sharded_dim_size
    gathered_input = torch.empty(gathered_input_size,
                                 device=local_shard.device)
    # all2all.
    local_shard = all_to_all_single(
        gathered_input,
        local_shard,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )
    local_tensor = local_shard.transpose(0, reshard_dim).contiguous()
    local_shards = [Shard(local_tensor, shards_metadata[current_rank])]
    return local_shards, shards_metadata
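
Conceptually, the single all2all above performs the two logical steps named in the docstring. The following single-process sketch is a hypothetical illustration of that equivalence using plain torch.chunk on a toy 1-D tensor; the placement orders are made up and the ShardingSpec machinery is replaced by simple rank lists.

import torch

full = torch.arange(10.0)                # The logical (unsharded) tensor, sharded on dim 0.
world_size = 4
old_placement_ranks = [1, 0, 3, 2]       # Original placement: chunk i lives on old_placement_ranks[i].
new_placement_ranks = [3, 2, 1, 0]       # Resharding placement.

# Step 1 (logical): collect all shards based on the original sharding spec.
chunks = torch.chunk(full, world_size, dim=0)            # Sizes [3, 3, 3, 1].
shard_of = dict(zip(old_placement_ranks, chunks))        # rank -> its local shard
reconstructed = torch.cat([shard_of[r] for r in old_placement_ranks])

# Step 2 (logical): reshard the tensor based on the resharding spec.
new_chunks = torch.chunk(reconstructed, world_size, dim=0)
new_shard_of = dict(zip(new_placement_ranks, new_chunks))
print({r: t.tolist() for r, t in sorted(new_shard_of.items())})
# Rank 3 now holds the first chunk, rank 0 the last (short) one.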
Example #3
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: row-wise sharded local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, the first N - 1 chunks always get the maximum
        # chunk size and the Nth chunk gets the remainder. So input_split_sizes like
        # [3, 3, 3, 4] are not possible; the expected split sizes would be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    local_result = reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
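
The row-rearrangement bookkeeping above can be traced on plain lists. The snippet below is a hypothetical sketch with a made-up placement order [2, 0, 3, 1] and a 13-row input_t; chunking is assumed to follow torch.chunk semantics, so every chunk except the last has the maximum size and chunk idx starts at idx * sharded_dim_size_max.

world_size, num_rows = 4, 13
placement_ranks = [2, 0, 3, 1]
split_size = -(-num_rows // world_size)               # 4 (ceil division)

input_split_sizes = [0] * world_size
for idx, rank in enumerate(placement_ranks):
    chunk = max(min(num_rows, split_size * (idx + 1)) - split_size * idx, 0)
    input_split_sizes[rank] = chunk
print(input_split_sizes)                              # [4, 1, 4, 4]

sharded_dim_size_max = max(input_split_sizes)         # 4
indices = [[] for _ in range(world_size)]
for idx, rank in enumerate(placement_ranks):
    start = idx * sharded_dim_size_max                # where chunk idx starts in input_t
    indices[rank] = list(range(start, start + input_split_sizes[rank]))
flat = [i for block in indices for i in block]
print(flat)    # [4, 5, 6, 7, 12, 0, 1, 2, 3, 8, 9, 10, 11]
# After index_select with these indices, rows are grouped by destination rank,
# which is the contiguous layout all_to_all_single expects.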
Example #4
def _handle_row_wise_sharding_tensor(input, world_size, weight, rank,
                                     local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current process.
        local_shard_t: row-wise sharded local weight used for the local matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`_PartialTensor` object which stores the partial local result.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.transpose(0, -1).contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, the first N - 1 chunks always get the maximum
        # chunk size and the Nth chunk gets the remainder. So input_split_sizes like
        # [3, 3, 3, 4] are not possible; the expected split sizes would be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input_size = [input_split_sizes[rank] * world_size] + list(
        input_t_size[1:])
    gathered_input = torch.empty(gathered_input_size, device=input_t.device)

    # Perform autograd enabled alltoall
    all_to_all_single(gathered_input,
                      input_t,
                      input_split_sizes=input_split_sizes,
                      group=pg)

    # Reshape gathered_input appropriately for matmul
    shard_size = local_shard_t.size()[0]
    reshaped_inputs = [
        torch.narrow(gathered_input, 0, r * shard_size,
                     shard_size).transpose(0, -1) for r in range(world_size)
    ]
    reshaped_input = torch.cat(reshaped_inputs)
    if reshaped_input.dim() == 1:
        reshaped_input = reshaped_input.view(-1, local_shard_t.size(0))

    # Perform appropriate local matmul
    if reshaped_input.dim() <= 2:
        result = torch.addmm(_BiasTensorPartial.apply(world_size, bias),
                             reshaped_input, local_shard_t)
    else:
        result = reshaped_input.matmul(
            local_shard_t) + _BiasTensorPartial.apply(world_size, bias)

    # Return the partial local result.
    return _PartialTensor(result, pg)
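
The partial result returned here only becomes the full linear output once the per-rank partials are summed. The single-process sketch below is a hypothetical sanity check of that property: it assumes the partial bias term simply contributes bias / world_size on each rank (an assumption about _BiasTensorPartial, not taken from the code above), and the sizes are made up.

import torch

torch.manual_seed(0)
world_size, in_features, out_features, batch = 4, 8, 6, 5
x = torch.randn(batch, in_features)
weight = torch.randn(out_features, in_features)        # Full (unsharded) weight.
bias = torch.randn(out_features)

# Row-wise sharding of the Linear weight splits the reduction dimension
# (in_features), so each rank's matmul only produces a partial sum.
weight_shards = torch.chunk(weight, world_size, dim=1)
x_shards = torch.chunk(x, world_size, dim=1)

partials = [
    x_shards[r].matmul(weight_shards[r].t()) + bias / world_size
    for r in range(world_size)
]

# Summing the partials (what reducing a _PartialTensor does) recovers the
# ordinary linear output.
full = torch.sum(torch.stack(partials), dim=0)
assert torch.allclose(full, x.matmul(weight.t()) + bias, atol=1e-5)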