Example #1
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor has only one dimension, we add an extra dimension
    for the reshard; we then need to squeeze it manually later on to reduce
    the dimension again.

    For example, if we have:
    input: size[15]
    weight: size[15, 16]
    world_size: 4

    On each rank, we will have 4 * [4] tensors. We then stack them into a [4, 4]
    tensor and generate a sharded tensor sharded by dim 1.

    In all other situations, we simply concatenate the local tensors; no further
    action is needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current CUDA process.
        local_shard_t: transposed col-wise sharded local weight used in the matmul.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object filled with the local intermediate results.
    """
    # allgather the inputs first.
    gathered_inputs = all_gather(input, group=pg)
    (start_pos,
     chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                             weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(world_size, start_pos, chunk_size,
                                         weight, pg, bias)
    results = [None] * world_size
    indices = {}
    for idx, placement in enumerate(weight._sharding_spec.placements):
        indices[placement.rank()] = idx
    for i, inp in enumerate(gathered_inputs):
        results[indices[i]] = inp.matmul(local_shard_t) + local_bias
    # When the local result has only one dimension, we need to make sure it
    # is not sharded by dim 0 so that the reshard can work properly.
    if results[0].dim() == 1:  # type: ignore[attr-defined]
        result = torch.stack(results)  # type: ignore[arg-type]
    else:
        result = torch.cat(results)  # type: ignore[arg-type]
    return _init_sharded_tensor_from_local_result(
        weight,
        result,
        0,
        -1,
        world_size,
        pg  # type: ignore[arg-type]
    )
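
The stack-vs-cat decision above can be seen in isolation with a minimal, single-process sketch using the sizes from the docstring example (input of size [15], weight of size [15, 16], world_size 4). The plain matmul, torch.chunk, and the random tensors below stand in for the real all_gather/ShardedTensor machinery (the actual code multiplies with the transposed local shard local_shard_t and adds a bias), so the names and sizes are assumptions for illustration only:

import torch

world_size = 4
inp = torch.randn(15)                            # 1-D input, as in the docstring example
weight = torch.randn(15, 16)                     # weight of size [15, 16]
shards = torch.chunk(weight, world_size, dim=1)  # col-wise shards, each of size [15, 4]

# Each simulated rank multiplies the (gathered) input with its local shard,
# producing a 1-D partial result of size [4].
results = [inp.matmul(shard) for shard in shards]

# 1-D partial results are stacked (-> [4, 4]) so dim 0 is not treated as the
# sharding dim during the later reshard; >=2-D results would simply be concatenated.
result = torch.stack(results) if results[0].dim() == 1 else torch.cat(results)
print(result.shape)                              # torch.Size([4, 4])
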
Example #2
def _handle_col_wise_sharding(input, world_size, weight, local_shard, max_norm,
                              norm_type, padding_idx, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for embedding. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding.)

    Args:
        input: list of IDs used for the lookup.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: col-wise sharded local weight used for the lookup.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
        pg: process group.

    Returns: final result of the lookup and the (possibly renormalized) local shard.
    """
    if not isinstance(input, ReplicatedTensor):
        # allgather the inputs first for a non-ReplicatedTensor input.
        gathered_inputs = all_gather(input, group=pg)
    else:
        gathered_inputs = input

    if max_norm is not None:
        # max_norm changes the weight in-place
        local_shard = _handle_max_norm_col_wise(max_norm, norm_type,
                                                local_shard, input, world_size,
                                                gathered_inputs, pg)

    output = _handle_col_wise_sharding_base(
        torch.nn.functional.embedding,
        len(input.size()),
        input,
        world_size,
        weight,
        local_shard,
        pg,
        gathered_inputs,
        padding_idx=padding_idx,
    )
    return (output, local_shard)
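
As a minimal, single-process sketch of what col-wise sharding means for the embedding lookup above: each rank holds a slice of the embedding dimension, runs torch.nn.functional.embedding against its slice, and concatenating the per-rank pieces along the last dimension reproduces the unsharded lookup. torch.chunk and the toy sizes below replace the real ShardedTensor/all_gather plumbing and the max_norm/padding_idx handling, so they are assumptions for illustration only:

import torch
import torch.nn.functional as F

world_size = 4
weight = torch.randn(10, 8)                      # full embedding table [num_embeddings, dim]
shards = torch.chunk(weight, world_size, dim=1)  # col-wise shards, each of size [10, 2]
ids = torch.tensor([[1, 3, 5], [2, 0, 9]])       # lookup IDs

# Each simulated rank looks up only its slice of the embedding dimension ...
partials = [F.embedding(ids, shard) for shard in shards]

# ... and concatenating along the last dim matches the unsharded lookup.
assert torch.equal(torch.cat(partials, dim=-1), F.embedding(ids, weight))
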
Example #3
def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg):
    """
    When we need to gather the input and all other parameters of the embeddingBag
    op, we stack them all together so that the ``all_gather`` collective
    communication is performed only once.

    Note that since offsets does not share the same size as input and
    is always smaller than input, we resize it during the communication.

    Args:
        input: tensor to be applied op on.
        per_sample_weights: weights for the weighted-sum mode.
        offsets: when input is 1D, offsets determines the starting
            index position of each bag (sequence) in input.
        pg: process group.

    Returns:
        gathered_inputs: list of input tensor gathered from each rank.
        gathered_per_sample_weights: list of per_sample_weights from each rank.
        gathered_offsets: list of offsets from each rank.
    """
    input_to_gather = [input]
    if per_sample_weights is not None:
        input_to_gather.append(per_sample_weights)
    if offsets is not None:
        input_to_gather.append(offsets.clone().resize_(input.size()))
    gathered_inputs = all_gather(torch.stack(input_to_gather), group=pg)

    gathered_per_sample_weights = None
    if per_sample_weights is not None:
        gathered_per_sample_weights = [t[1] for t in gathered_inputs]
    gathered_offsets = None
    if offsets is not None:
        idx = 2 if per_sample_weights is not None else 1
        gathered_offsets = [
            t[idx].resize_(offsets.size()).to(offsets.dtype) for t in gathered_inputs
        ]
    gathered_inputs = [t[0].to(input.dtype) for t in gathered_inputs]
    return gathered_inputs, gathered_per_sample_weights, gathered_offsets
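
A minimal, single-process sketch of the pack-once idea: pad offsets to the input's length, stack the three tensors so a single collective would suffice, and unpack afterwards. The gathered list is faked with two copies of the packed tensor, and zero-padding plus plain slicing stand in for the resize_ trick used above; every name and size here is an assumption for illustration only:

import torch

inp = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])     # embedding_bag input IDs
per_sample_weights = torch.rand(8)
offsets = torch.tensor([0, 4])                   # two bags

# Pad offsets to the input's length so all three tensors can be stacked into
# a single [3, 8] tensor and shipped with one all_gather call.
padded_offsets = torch.zeros_like(inp, dtype=torch.float)
padded_offsets[: offsets.numel()] = offsets
packed = torch.stack([inp.to(torch.float), per_sample_weights, padded_offsets])

# Pretend this is what all_gather would return from two ranks.
gathered = [packed, packed.clone()]

# Unpack per rank and restore the original dtypes and sizes.
gathered_inputs = [t[0].to(inp.dtype) for t in gathered]
gathered_per_sample_weights = [t[1] for t in gathered]
gathered_offsets = [t[2][: offsets.numel()].to(offsets.dtype) for t in gathered]
print(gathered_inputs[0], gathered_offsets[0])
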
Example #4
def _handle_col_wise_sharding_base(
    op_func,
    sharding_dim_size,
    col_dim,
    input,
    world_size,
    weight,
    local_shard,
    pg,
    gathered_inputs=None,
    mode=None,
    gathered_per_sample_weights=None,
    gathered_offsets=None,
    padding_idx=None,
):
    """
    For col-wise sharding of weight, much of the logic is common,
    so we extract it into this function:
    Step 1. Gather the input from each rank.
    Step 2. Perform the op on the gathered inputs.
    Step 3. Distribute the results to each rank with column rearrangement.
    Step 4. Concatenate all results from all ranks.

    Args:
        op_func: operator which is applied to the input tensor.
        sharding_dim_size: the max size of the column each rank gets.
        col_dim: dim of result tensor after the operation.
        input: tensor to be applied op on.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: col-wise sharded weight tensor.
        pg: process group.
        gathered_inputs: list of inputs from all ranks. If specified, we
            don't need to communicate with each rank any more.
        mode: aggregation mode of EmbeddingBag.
        gathered_per_sample_weights: per_sample_weights across all ranks.
        gathered_offsets: offsets across all ranks.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
            Note that the embedding vector at padding_idx is
            excluded from the reduction.

    Returns: final result of applying the op to the input.
    """
    if gathered_inputs is None:
        # allgather the inputs first.
        gathered_inputs = all_gather(input, group=pg)

    # run the operator's function for all the inputs.
    results = []
    for i, inp in enumerate(gathered_inputs):
        if op_func == torch.nn.functional.embedding_bag:
            result = op_func(
                inp,
                local_shard,
                offsets=gathered_offsets[i]
                if gathered_offsets is not None else None,
                mode=mode,
                per_sample_weights=gathered_per_sample_weights[i]
                if gathered_per_sample_weights is not None else None,
                padding_idx=padding_idx,
            )
        elif op_func == torch.nn.functional.embedding:
            result = op_func(
                inp,
                local_shard,
                padding_idx=padding_idx,
            )
        else:
            result = op_func(inp, local_shard)
        results.append(torch.transpose(result, 0, col_dim))

    # Distribute results to each rank with col rearrangement.
    output = _result_distribute_with_col_rearrange(results, input,
                                                   sharding_dim_size,
                                                   world_size, weight, pg)

    # transpose the output and return result.
    return torch.transpose(output, 0, col_dim)
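
A minimal, single-process sketch of Steps 2-4 for the generic `op_func(inp, local_shard)` branch, with a plain matmul standing in for the Linear case: apply the op per column shard, move the sharded dim to dim 0 via transpose, concatenate, and transpose back. The all_gather and the placement-based reordering done by _result_distribute_with_col_rearrange are omitted, so all names and sizes below are assumptions for illustration only:

import torch

world_size = 4
inp = torch.randn(5, 15)                            # [batch, in_features]
weight_t = torch.randn(15, 16)                      # transposed weight: [in_features, out_features]
shards = torch.chunk(weight_t, world_size, dim=1)   # col-wise shards, each of size [15, 4]
col_dim = 1                                         # the sharded dim of each local result

# Step 2/3: apply the op per shard and move the sharded dim to dim 0, so the
# per-rank pieces can be concatenated (and, in the real code, reordered to
# match the placement order) along dim 0.
results = [torch.transpose(inp.matmul(shard), 0, col_dim) for shard in shards]
output = torch.cat(results, dim=0)                  # [16, 5]

# Step 4: transpose back; this matches the unsharded computation.
final = torch.transpose(output, 0, col_dim)         # [5, 16]
assert torch.allclose(final, inp.matmul(weight_t))
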