Example #1
def test_get_chunk_sharding_params(self):
    ranks = [
        "rank:0/cuda:0",
        "rank:1/cuda:1",
        "rank:2/cuda:2",
        "rank:3/cuda:3",
    ]
    spec = ChunkShardingSpec(
        dim=0,
        placements=ranks,
    )
    # 21 rows chunked across 4 placements gives chunk sizes 6, 6, 6, 3,
    # handed out in placement order; the result is (start_offset, chunk_size).
    result = get_chunk_sharding_params(21, 4, spec, 1)
    self.assertEqual(6, result[0])
    self.assertEqual(6, result[1])
    result = get_chunk_sharding_params(21, 4, spec, 3)
    self.assertEqual(18, result[0])
    self.assertEqual(3, result[1])
    # Reverse the placement order; each rank's offset now follows its new
    # position in the placements.
    ranks[1], ranks[2] = ranks[2], ranks[1]
    ranks[0], ranks[3] = ranks[3], ranks[0]
    spec.placements = ranks
    result = get_chunk_sharding_params(21, 4, spec, 1)
    self.assertEqual(12, result[0])
    self.assertEqual(6, result[1])
    result = get_chunk_sharding_params(21, 4, spec, 3)
    self.assertEqual(0, result[0])
    self.assertEqual(6, result[1])
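For reference, the offsets asserted above follow the usual chunk split: 21 rows across 4 placements gives chunks of 6, 6, 6, 3, assigned in placement order. Below is a minimal stand-alone sketch of that arithmetic, assuming this is how get_chunk_sharding_params derives its result; the helper name and signature are hypothetical, for illustration only.

def chunk_offset_and_size(total_size, world_size, placement_ranks, rank):
    # Split `total_size` rows into chunks of ceil(total_size / world_size) rows,
    # with the trailing chunk(s) possibly smaller, handed out in placement order.
    chunk = -(-total_size // world_size)      # ceil division
    idx = placement_ranks.index(rank)         # position of `rank` in the placements
    start = min(idx * chunk, total_size)
    size = max(min(total_size - start, chunk), 0)
    return start, size

# Mirrors the assertions of the test above.
assert chunk_offset_and_size(21, 4, [0, 1, 2, 3], 1) == (6, 6)
assert chunk_offset_and_size(21, 4, [0, 1, 2, 3], 3) == (18, 3)
# After swapping the placements as in the test, the order is [3, 2, 1, 0].
assert chunk_offset_and_size(21, 4, [3, 2, 1, 0], 1) == (12, 6)
assert chunk_offset_and_size(21, 4, [3, 2, 1, 0], 3) == (0, 6)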
Example #2
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor has only one dimension, we add an extra dimension
    for the reshard and need to squeeze it out manually later on.

    For example, if we have:
    input: size[15]
    weight: size[15, 16]
    world_size: 4

    On each rank, we will have 4 tensors of size [4]. We then stack them into
    a [4, 4] tensor and generate a sharded tensor sharded by dim 1.

    In all other cases, we simply concatenate the local tensors; no further
    action is needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: number of the current CUDA process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object filled with the local intermediate results.
    """
    # allgather the inputs first.
    gathered_inputs = all_gather(input, group=pg)
    (start_pos,
     chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                             weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(world_size, start_pos, chunk_size,
                                         weight, pg, bias)
    results = [None] * world_size
    indices = {}
    for idx, placement in enumerate(weight._sharding_spec.placements):
        indices[placement.rank()] = idx
    for i, inp in enumerate(gathered_inputs):
        results[indices[i]] = inp.matmul(local_shard_t) + local_bias
    # When the local result has only one dimension, make sure it is not
    # sharded by dim 0 so that the reshard can work properly.
    if results[0].dim() == 1:  # type: ignore[attr-defined]
        result = torch.stack(results)  # type: ignore[arg-type]
    else:
        result = torch.cat(results)  # type: ignore[arg-type]
    return _init_sharded_tensor_from_local_result(
        weight,
        result,
        0,
        -1,
        world_size,
        pg  # type: ignore[arg-type]
    )
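As a small sanity check outside any process group, here is a plain-torch illustration (with made-up shapes) of the stack-versus-cat branch at the end of the function: one-dimensional local results are stacked into a 2-D tensor that can be sharded by dim 1, while higher-dimensional results are simply concatenated along dim 0.

import torch

world_size = 4
# 1-D case from the docstring: four local results of shape [4] stack into [4, 4].
one_d_results = [torch.randn(4) for _ in range(world_size)]
assert torch.stack(one_d_results).shape == (4, 4)

# Higher-dimensional case: local results of shape [3, 4] are concatenated along dim 0.
two_d_results = [torch.randn(3, 4) for _ in range(world_size)]
assert torch.cat(two_d_results).shape == (12, 4)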
Example #3
def _handle_row_wise_mask(gather_inp, padding_idx, weight, world_size, rank):
    """
    Mask the input for embedding look-up for IDs which are not stored
    on the current rank. This function also adjusts the ``padding_idx``
    so that it is only used on the rank where the corresponding row is
    stored.

    Note that, with the ``max_norm`` flag on, only the weights of rows being
    looked up are re-normed, so we need an extra row for masked IDs so that
    they affect neither the final result nor ``max_norm``.

    Args:
        gather_inp: input IDs gathered from all ranks that the op is applied to.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
            Note that the embedding vector at padding_idx is
            excluded from the reduction.
        weight: weight tensor of Embedding look-up table.
        world_size: number of ranks.
        rank: number of the current CUDA process.

    Returns:
        lookup_input: Tensor of masked input.
        padding_idx: adjusted padding_idx.
        padding_row: the extra row used during lookup so that the
            lookup does not affect ``max_norm``.
    """
    (start_pos,
     chunk_size) = get_chunk_sharding_params(weight.size(0), world_size,
                                             weight._sharding_spec, rank)
    mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size)
    lookup_input = gather_inp.clone() - start_pos
    lookup_input[mask] = chunk_size
    if (padding_idx is not None and padding_idx >= start_pos and padding_idx <
        (start_pos + chunk_size)):
        padding_idx = padding_idx - start_pos
    else:
        padding_idx = None

    # When max_norm is set, it will only re-norm the row being looked up.
    padding_row = torch.zeros(1,
                              weight.size(1),
                              device=gather_inp.device,
                              dtype=weight.dtype)
    return lookup_input, padding_idx, padding_row
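A minimal plain-torch illustration of the masking arithmetic on a toy shard follows; the start_pos, chunk_size, and embedding width are made up, and the zero padding row is appended to the local weight by hand, which is what the caller is expected to do with the returned padding_row.

import torch

# Suppose this rank stores rows [4, 8) of the full table: start_pos=4, chunk_size=4.
start_pos, chunk_size = 4, 4
gather_inp = torch.tensor([0, 5, 7, 11])

mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size)
lookup_input = gather_inp.clone() - start_pos
lookup_input[mask] = chunk_size            # IDs not stored here point at the extra row

local_weight = torch.randn(chunk_size, 8)  # this rank's shard of the table
padding_row = torch.zeros(1, 8)
emb = torch.nn.functional.embedding(lookup_input,
                                    torch.cat([local_weight, padding_row]))
assert torch.all(emb[mask] == 0)           # masked IDs hit the zero padding row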
Example #4
def _handle_col_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    """
    Entry-point function to handle the logic of col-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    When the local tensor has only one dimension, we add an extra dimension
    for the reshard and need to squeeze it out manually later on.

    For example, if we have:
    input: size[15]
    weight: size[15, 16]
    world_size: 4

    On each rank, we will have 4 tensors of size [4]. We then stack them into
    a [4, 4] tensor and generate a sharded tensor sharded by dim 1.

    In all other cases, we simply concatenate the local tensors; no further
    action is needed afterward.

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: number of the current CUDA process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        A :class:`ShardedTensor` object filled with the local intermediate results.
    """
    # allgather the inputs first.
    out_size = list(input.size())
    out_size[0] = input.size(0) * dist.get_world_size(pg)
    output = torch.empty(out_size, device=input.device)
    output = _all_gather_base(output, input, group=pg)

    # Adjust bias and perform local matmul.
    (start_pos,
     chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                             weight._sharding_spec, rank)
    local_bias = _BiasTensorNarrow.apply(world_size, start_pos, chunk_size,
                                         weight, pg, bias)

    if output.dim() == 1:
        output = output.view(dist.get_world_size(pg), -1)

    if output.dim() <= 2:
        # Use fused version if possible.
        result = torch.addmm(local_bias, output, local_shard_t)
    else:
        result = output.matmul(local_shard_t) + local_bias

    # Build ShardedTensor as result.
    st_size = list(result.size())
    st_size[-1] = weight.size(0)
    new_sharding_spec = ChunkShardingSpec(
        dim=-1, placements=weight.sharding_spec().placements)
    return ShardedTensor._init_from_local_tensor(
        result,
        new_sharding_spec,
        *st_size,  # type: ignore[arg-type]
        process_group=pg,
    )
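Outside any process group, a quick check with made-up toy shapes that the fused torch.addmm branch above is equivalent to the unfused matmul-plus-bias branch:

import torch

output = torch.randn(6, 15)          # gathered inputs: 6 rows of 15 in-features
local_shard_t = torch.randn(15, 4)   # transposed local shard: 4 output columns
local_bias = torch.randn(4)          # this rank's slice of the bias

fused = torch.addmm(local_bias, output, local_shard_t)
unfused = output.matmul(local_shard_t) + local_bias
assert torch.allclose(fused, unfused, atol=1e-6)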
Example #5
def sharded_layer_norm(args, kwargs, pg):
    """
    Handles ``__torch_function__`` dispatch for the ``torch.nn.LayerNorm`` op.
    We compute the global mean and variance across all shards via all-reduce
    of local partial sums, then normalize the local shard on each rank.

    Args: same as ``torch.nn.LayerNorm``.

    Returns:
        local_tensor (Tensor): New local tensor to build the sharded tensor.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
            sharding spec of the new sharded tensor.
        new_st_size (torch.Size): Size of the new sharded tensor.
    """
    st = args[0]
    normalized_shape = args[1]
    sharding_dim = st.sharding_spec().dim  # type: ignore[attr-defined]
    sharding_dim = sharding_dim if sharding_dim >= 0 else st.dim() + sharding_dim
    local_tensor = st.local_tensor()
    # If the sharding dim comes before the normalized dims, each rank already
    # holds complete normalized slices, so a plain local layer norm suffices.
    shape_start = st.dim() - len(normalized_shape)
    if shape_start > sharding_dim:
        args = (local_tensor, *args[1:])
        local_tensor = torch.nn.functional.layer_norm(*args, **kwargs)
        return local_tensor, st.sharding_spec(), st.size()

    elementwise_affine = kwargs.get("elementwise_affine", False)
    eps = kwargs.get("eps", 1e-05)

    norm_dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
    local_size = math.prod(
        local_tensor.size()[shape_start:])  # type: ignore[attr-defined]
    st_size = math.prod(st.size()[shape_start:])  # type: ignore[attr-defined]
    local_mean = torch.mul(local_tensor.mean(norm_dims, keepdim=True),
                           local_size)
    global_mean = torch.div(all_reduce(local_mean), st_size)
    local_variant_sq = torch.square(local_tensor - global_mean).sum(
        norm_dims, keepdim=True)
    global_variant = torch.div(all_reduce(local_variant_sq), st_size)

    denom = torch.rsqrt(global_variant + eps)
    local_tensor = torch.mul(local_tensor - global_mean, denom)

    if elementwise_affine:
        weight = kwargs["weight"]
        bias = kwargs["bias"]
        current_rank = dist.get_rank(pg)  # type: ignore[attr-defined]
        world_size = dist.get_world_size(pg)
        (start_pos,
         chunk_size) = get_chunk_sharding_params(bias.size(0), world_size,
                                                 st.sharding_spec(),
                                                 current_rank)
        local_tensor = torch.addmm(
            torch.narrow(bias, 0, start_pos, chunk_size),
            local_tensor,
            torch.narrow(weight, sharding_dim - shape_start, start_pos,
                         chunk_size),
        )

    return local_tensor, st.sharding_spec(), st.size()
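Below is a minimal single-process sketch (with made-up shapes, simulating two ranks by slicing one tensor along the normalized dimension) showing that combining per-shard partial sums this way reproduces torch.nn.functional.layer_norm without the affine step; the sum over the list stands in for the all_reduce calls above.

import torch

full = torch.randn(3, 8)
shards = full.chunk(2, dim=-1)       # two "local tensors" of shape [3, 4]
eps = 1e-5
st_size = full.size(-1)              # number of normalized elements per row

local_means = [s.mean(-1, keepdim=True) * s.size(-1) for s in shards]
global_mean = sum(local_means) / st_size
local_var_sq = [torch.square(s - global_mean).sum(-1, keepdim=True) for s in shards]
global_var = sum(local_var_sq) / st_size

denom = torch.rsqrt(global_var + eps)
normalized = torch.cat([(s - global_mean) * denom for s in shards], dim=-1)

reference = torch.nn.functional.layer_norm(full, (8,), eps=eps)
assert torch.allclose(normalized, reference, atol=1e-6)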