Example 1
 def forward(ctx: Any, group: dist.ProcessGroup,
             input: Tensor) -> Tensor:  # type: ignore
     ctx.group = group
     input = input.contiguous()
     output = torch.empty_like(input)
     dist.all_to_all_single(output, input, group=group)
     return output
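Since an all-to-all over the same process group is its own inverse, the backward of such an autograd Function is typically just another all_to_all_single applied to the incoming gradient. A minimal sketch of a matching backward, assuming the same imports as the forward above and the `ctx.group` saved there (this is an illustration, not part of the original snippet):

 @staticmethod
 def backward(ctx: Any, grad_output: Tensor):  # hypothetical sketch
     # All-to-all is self-inverse over the same group, so route the gradient back.
     grad_output = grad_output.contiguous()
     grad_input = torch.empty_like(grad_output)
     dist.all_to_all_single(grad_input, grad_output, group=ctx.group)
     # No gradient flows to the process-group argument.
     return (None, grad_input)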
Example 2
def timed_all_to_all(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()

    # Time the actual comm op for args.trials iterations and average the duration
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # Compute and format the performance metrics
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_to_all', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(
        f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
Example 3
def _communicate_size_to_each_rank(input_size_list,
                                   output_size,
                                   input,
                                   pg,
                                   tensor_type=torch.int):
    """
    In the case of row-wise sharding of the weight, we first need to
    communicate the input length to each rank because each rank gets a
    different one.

    Args:
        input_size_list: list of sizes to be sent to each rank.
        output_size: length of the output tensor.
        input: tensor the op is applied to.
        pg: process group.
        tensor_type: dtype of the tensor.

    Return: A list of communication results (int).
    """
    input_size_list_tensor = torch.tensor(input_size_list,
                                          dtype=tensor_type,
                                          device=input.device)
    output_size_list_tensor = torch.empty(output_size,
                                          dtype=tensor_type,
                                          device=input.device)
    dist.all_to_all_single(
        output_size_list_tensor,
        input_size_list_tensor,
        group=pg,
    )
    return output_size_list_tensor.tolist()
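A hedged usage sketch of how the exchanged sizes are typically consumed: each rank announces how many elements it will send to every peer, and the returned list becomes the output_split_sizes of the variable-size all-to-all that follows. The names flat_input, world_size, pg and the planning helper below are illustrative, not from the original code:

# Hypothetical flat 1-D payload already ordered by destination rank.
input_split_sizes = plan_elements_per_destination(flat_input, world_size)  # hypothetical helper
output_split_sizes = _communicate_size_to_each_rank(
    input_split_sizes, world_size, flat_input, pg)
# Size the receive buffer by what every peer said it would send to us.
recv_buffer = torch.empty(sum(output_split_sizes),
                          dtype=flat_input.dtype,
                          device=flat_input.device)
dist.all_to_all_single(recv_buffer,
                       flat_input,
                       input_split_sizes=input_split_sizes,
                       output_split_sizes=output_split_sizes,
                       group=pg)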
Example 4
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
Example 5
 def _alltoall(self, dispatched_attention):
     if dist.get_world_size(group=self.ep_group) > 1:
         dispatched_input = torch.empty_like(dispatched_attention)
         dist.all_to_all_single(dispatched_input,
                                dispatched_attention,
                                group=self.ep_group)
         return dispatched_input
     else:
         return dispatched_attention
Example 6
def _result_distribute_with_col_rearrange(results, input, sharding_dim_size,
                                          world_size, weight, pg):
    """
    For col-wise sharding of the weight, we need to distribute the
    results back to each rank; this function does that.
    Note that, if the index in the Sharding Spec is not equal to
    the rank number, we need to rearrange the results based on the
    order given by the Sharding Spec (placement).

    Args:
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor the op is applied to.
        sharding_dim_size: the max size of the column each rank gets.
        world_size: number of ranks.
        weight: sharded weight tensor.
        pg: process group.

    Return: column rearranged result.
    """
    # Process results and outputs for all2all.
    dims = list(results[0].size())
    dims[0] = sharding_dim_size
    output = torch.empty(*dims, device=input.device)
    combined_results = torch.cat(results)

    # Compute output splits
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for idx, placement in enumerate(weight._sharding_spec.placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, idx)

    # distribute the outputs using all2all.
    dist.all_to_all_single(output,
                           combined_results,
                           output_split_sizes=output_split_sizes,
                           group=pg)

    # Check if we need to rearrange columns appropriately for output.
    rearrange_columns = any([
        idx != placement.rank()
        for idx, placement in enumerate(weight._sharding_spec.placements)
    ])
    if not rearrange_columns:
        return output

    indices = []
    for placement in weight._sharding_spec.placements:
        dim_size = output_split_sizes[placement.rank()]
        start = sum([
            split_size if i < placement.rank() else 0
            for i, split_size in enumerate(output_split_sizes)
        ])
        indices += list(range(start, start + dim_size))

    return output.index_select(0, torch.tensor(indices, device=output.device))
Example 7
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t,
                              bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[int] = []
        for placement in weight._sharding_spec.placements:
            sharded_dim_size = get_chunked_dim_size(input_t_size[0],
                                                    split_size,
                                                    placement.rank())
            input_idx = placement.rank() * split_size
            indices += range(input_idx, input_idx + sharded_dim_size)

        input_t = input_t.index_select(
            0, torch.tensor(indices, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
Example 8
    def wrapper(*args, **kwargs):
        group = kwargs.get('group', None)
        async_op = kwargs.get('async_op', False)
        if (async_op is True):
            raise RuntimeError('The async_op=True mode is not supported yet.')
        if (func == dist.all_gather):
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors,
                            input_tensors,
                            group=group,
                            async_op=async_op)
            for i, t in enumerate(
                    _dequantize_tensor_list(out_tensors,
                                            qtype,
                                            quant_loss=quant_loss)):
                tensors[i] = t

        elif (func == dist.all_to_all):
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors,
                            input_tensors,
                            group=group,
                            async_op=async_op)
            for i, t in enumerate(
                    _dequantize_tensor_list(out_tensors,
                                            qtype,
                                            quant_loss=quant_loss)):
                tensors[i] = t

        elif (func == dist.all_to_all_single):
            tensors = args[0]
            out_splits = kwargs.get('out_splits', None)
            in_splits = kwargs.get('in_splits', None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(out_tensors,
                                   input_tensors,
                                   out_splits,
                                   in_splits,
                                   group=group)
            for i, t in enumerate(
                    _dequantize_tensor(out_tensors,
                                       qtype,
                                       quant_loss=quant_loss)):
                tensors[i] = t
        else:
            raise RuntimeError(
                f"The collective op {func} is not supported yet")
Example 9
def reshard_flatten_tensor(
    input_tensor: ShardedTensor,
    output_spec: ShardingSpec,
    world_size: int,
    my_rank: int,
    device: torch.device,
    process_group: Optional[dist.ProcessGroup],
) -> torch.Tensor:
    """
    Reshard a sharded flattened tensor. This is used by FSDP to produce a
    sharded state_dict, since the functionality is not supported by
    ShardedTensor. This API is designed for FSDP and therefore supports only
    1-D ShardedTensors (hence the name, reshard_flatten_tensor).

    This API uses the ChunkShardingSpec and EnumerableShardingSpec from
    torch.distributed.sharding_spec but ignores the placement field in
    ChunkShardingSpec, as the placement requires that the callees understand
    the number of GPUs per node. The API simply uses the semantics of the
    sharding specs.

    Args:
        input_tensor (ShardedTensor): the original ShardedTensor. Must be 1D.
        output_spec (ShardingSpec): the sharding spec for the output tensor.
        world_size (int): total trainer count.
        my_rank (int): the rank for this trainer.

    Returns:
        The local shard for the new ShardedTensor.
    """

    input_spec = input_tensor.sharding_spec()
    size = input_tensor.size()
    if isinstance(size, int):
        raise ValueError("The input tensor has no dimensions.")
    tensor_numel = size.numel()
    input_offsets = _sharding_spec_to_offsets(input_spec, tensor_numel,
                                              world_size)
    output_offsets = _sharding_spec_to_offsets(output_spec, tensor_numel,
                                               world_size)
    input_split_sizes, output_split_sizes = _offsets_to_split_sizes(
        input_offsets, output_offsets, tensor_numel, world_size, my_rank)
    output_size = sum(output_split_sizes)
    local_shard = torch.empty(output_size,
                              dtype=input_tensor.dtype,
                              device=device)
    dist.all_to_all_single(
        local_shard,
        input_tensor.local_shards()[0].tensor,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=process_group,
    )
    return local_shard
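The split sizes above come from intersecting the old and new shard layouts. Below is a minimal sketch of how split sizes could be derived from per-rank offsets, assuming each offsets list gives, in rank order, the starting element index of every rank's contiguous 1-D shard. This is an illustration of the idea, not the library's `_offsets_to_split_sizes` implementation:

def _split_sizes_from_offsets_sketch(input_offsets, output_offsets, numel,
                                     world_size, my_rank):
    # Turn per-rank start offsets into half-open [start, end) ranges.
    def to_ranges(offsets):
        return list(zip(offsets, list(offsets[1:]) + [numel]))

    in_ranges, out_ranges = to_ranges(input_offsets), to_ranges(output_offsets)

    def overlap(a, b):
        return max(0, min(a[1], b[1]) - max(a[0], b[0]))

    # Send rank i the part of my current shard that lands in its new range;
    # receive from rank i the part of its current shard that lands in mine.
    input_split_sizes = [overlap(in_ranges[my_rank], out_ranges[i])
                         for i in range(world_size)]
    output_split_sizes = [overlap(in_ranges[i], out_ranges[my_rank])
                          for i in range(world_size)]
    return input_split_sizes, output_split_sizes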
Example 10
 def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
     ctx.group = group
     ctx.input_size = input.size()
     # Note the deliberate swap: the backward all-to-all sends what the
     # forward received and receives what the forward sent.
     ctx.output_split_sizes = input_split_sizes
     ctx.input_split_sizes = output_split_sizes
     dist.all_to_all_single(
         output,
         input,
         output_split_sizes=output_split_sizes,
         input_split_sizes=input_split_sizes,
         group=group,
     )
     return output
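Saving the split sizes swapped pays off in the backward pass: the gradient travels through an all-to-all whose send and receive sizes are exchanged relative to the forward. A hedged sketch of a matching backward (using ctx.input_size to size the gradient buffer is an assumption, not part of the original snippet):

 @staticmethod
 def backward(ctx, grad_output):  # hypothetical sketch
     grad_input = grad_output.new_empty(ctx.input_size)
     dist.all_to_all_single(
         grad_input,
         grad_output.contiguous(),
         output_split_sizes=ctx.output_split_sizes,  # forward's input_split_sizes
         input_split_sizes=ctx.input_split_sizes,    # forward's output_split_sizes
         group=ctx.group,
     )
     # No gradients for group, the output buffer, or the split-size lists.
     return (None, None, None, None, grad_input)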
Example 11
def _handle_col_wise_sharding(input, world_size, weight, local_shard_t, bias,
                              pg):
    # allgather the inputs first.
    gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
    dist.all_gather(gathered_inputs, input, group=pg)

    # matmul all the inputs.
    results = []
    for i, inp in enumerate(gathered_inputs):
        results.append(inp.matmul(local_shard_t).t())

    # Process inputs and outputs for all2all.
    sharding_dim_size = weight.size()[0]
    output = torch.empty((sharding_dim_size, input.size(0)),
                         device=input.device)
    combined_results = torch.cat(results)

    # Compute output splits
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [
        get_chunked_dim_size(sharding_dim_size, split_size, placement.rank())
        for placement in weight._sharding_spec.placements
    ]

    # distribute the outputs using all2all.
    dist.all_to_all_single(output,
                           combined_results,
                           output_split_sizes=output_split_sizes,
                           group=pg)

    # Check if we need to rearrange rows appropriately for output.
    rearrange_rows = any([
        idx != placement.rank()
        for idx, placement in enumerate(weight._sharding_spec.placements)
    ])
    if rearrange_rows:
        indices = []
        for placement in weight._sharding_spec.placements:
            dim_size = output_split_sizes[placement.rank()]
            start = sum([
                split_size if i < placement.rank() else 0
                for i, split_size in enumerate(output_split_sizes)
            ])
            indices += list(range(start, start + dim_size))

        output = output.index_select(
            0, torch.tensor(indices, device=output.device))

    # add bias and return result.
    return output.t() + bias
Example 12
 def forward(ctx, a2ai, *inputs):
     global myreq
     mb_split_lengths = a2ai.gNS
     if mb_split_lengths:
         mb_split_lengths = [m * a2ai.lS * a2ai.E for m in mb_split_lengths]
     emb_split_lengths = a2ai.gSS
     if emb_split_lengths:
         emb_split_lengths = [
             a2ai.lN * e * a2ai.E for e in emb_split_lengths
         ]
     input = torch.cat(inputs, dim=1).view([-1])
     output = input.new_empty([a2ai.S * a2ai.lN * a2ai.E])
     req = dist.all_to_all_single(output,
                                  input,
                                  emb_split_lengths,
                                  mb_split_lengths,
                                  async_op=True)
     myreq.req = req
     myreq.tensor = []
     myreq.tensor.append(output)
     myreq.tensor = tuple(myreq.tensor)
     a2ai.mb_split_lengths = mb_split_lengths
     a2ai.emb_split_lengths = emb_split_lengths
     myreq.a2ai = a2ai
     ctx.a2ai = a2ai
     return myreq.tensor
Example 13
def alltoall_test():
    global comm_rank, comm_size
    dev = torch.device('cuda')
    t_send = torch.zeros(comm_size, device=dev) + comm_rank
    t_recv = torch.zeros(comm_size, device=dev)
    dist.all_to_all_single(t_recv, t_send)
    # After the exchange, t_recv on every rank holds [0, 1, ..., comm_size - 1].
    t_recv = t_recv + 1
    # Summing [1, ..., comm_size] across ranks gives comm_size * [1, ..., comm_size].
    dist.all_reduce(t_recv)
    if torch.all(
            torch.eq(
                t_recv,
                comm_size *
                torch.arange(start=1, end=comm_size + 1, device=dev))):
        print(f"Rank {comm_rank}: success")
    else:
        print(f"Rank {comm_rank}: failed")
Example 14
    def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
        # pair=True mode does not support quantization
        if (
            collectiveArgs.all2all_qcomm
            and collectiveArgs.ipTensor.dtype == torch.float32
            and (
                collectiveArgs.opTensor.nelement() >= collectiveArgs.quan_threshold
                or collectiveArgs.ipTensor.nelement() >= collectiveArgs.quan_threshold
            )
            and not pair
        ):
            work = all_to_allv_internal(collectiveArgs)
        else:
            work = dist.all_to_all_single(
                collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
                collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
                collectiveArgs.opTensor_split if not pair else collectiveArgs.opTensor_split_pair,
                collectiveArgs.ipTensor_split if not pair else collectiveArgs.ipTensor_split_pair,
                group=collectiveArgs.group,
                async_op=collectiveArgs.asyncOp,
            )

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(work)

        if retFlag:
            return work
Example 15
    def forward(ctx, a2a_info, *inputs):
        global myreq
        with record_function("DLRM alltoall_req_fwd_single"):
            batch_split_lengths = a2a_info.global_batch_partition_slices
            if batch_split_lengths:
                batch_split_lengths = [
                    m * a2a_info.emb_dim * a2a_info.local_table_num
                    for m in batch_split_lengths
                ]
            table_split_lengths = a2a_info.global_table_wise_parition_slices
            if table_split_lengths:
                table_split_lengths = [
                    a2a_info.local_batch_num * e * a2a_info.emb_dim
                    for e in table_split_lengths
                ]
            input = torch.cat(inputs, dim=1).view([-1])
            output = input.new_empty([
                a2a_info.global_table_num * a2a_info.local_batch_num *
                a2a_info.emb_dim
            ])
            req = dist.all_to_all_single(output,
                                         input,
                                         table_split_lengths,
                                         batch_split_lengths,
                                         async_op=True)

            myreq.req = req
            myreq.tensor = []
            myreq.tensor.append(output)
            myreq.tensor = tuple(myreq.tensor)
            a2a_info.batch_split_lengths = batch_split_lengths
            a2a_info.table_split_lengths = table_split_lengths
            myreq.a2a_info = a2a_info
            ctx.a2a_info = a2a_info
            return myreq.tensor
Example 16
def all_to_all(tensor, group):
    """Perform an all-to-all operation on a 1D Tensor."""
    assert tensor.dim() == 1
    split_count = get_world_size(group=group)
    assert tensor.numel() % split_count == 0
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        return xm.all_to_all(
            tensor,
            split_dimension=0,
            concat_dimension=0,
            split_count=split_count,
            groups=group[1],
        )
    else:
        output = torch.zeros_like(tensor)
        dist.all_to_all_single(output, tensor, group=group)
        return output
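A small usage illustration of the non-XLA branch, assuming an already-initialized 2-rank default group and module-level helpers that behave like their torch.distributed counterparts (values chosen for clarity, not from the original code):

# Hypothetical 2-rank run: rank 0 holds [0., 1.], rank 1 holds [10., 11.].
rank = dist.get_rank()
tensor = torch.tensor([0., 1.]) if rank == 0 else torch.tensor([10., 11.])
shuffled = all_to_all(tensor, group=dist.group.WORLD)
# Element i of the result came from rank i:
# rank 0 now holds [0., 10.], rank 1 holds [1., 11.].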
Example 17
 def all_to_all(self, collectiveArgs, retFlag=False):
     retObj = dist.all_to_all_single(
         collectiveArgs.opTensor,
         collectiveArgs.ipTensor,
         group=collectiveArgs.group,
         async_op=collectiveArgs.asyncOp,
     )  # synchronicity is maintained in runColl
     if retFlag:
         return retObj
     else:
         return
Example 18
 def backward(ctx, *grad_outputs):
     global myreq
     #print("All2All_Wait:backward")
     a2ai = ctx.a2ai
     grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
     grad_output = torch.cat(grad_outputs)
     grad_input = grad_output.new_empty([a2ai.N * a2ai.lS * a2ai.E])
     req = dist.all_to_all_single(grad_input, grad_output, a2ai.mb_split_lengths, a2ai.emb_split_lengths, async_op=True)
     myreq.req = req
     myreq.tensor = grad_input
     return (grad_output,)
Example 19
 def all_to_allv(self, collectiveArgs, retFlag=False):
     if collectiveArgs.all2all_qcomm:
         work = all_to_allv_internal(collectiveArgs)
     else:
         work = dist.all_to_all_single(
             collectiveArgs.opTensor,
             collectiveArgs.ipTensor,
             collectiveArgs.opTensor_split,
             collectiveArgs.ipTensor_split,
             group=collectiveArgs.group,
             async_op=collectiveArgs.asyncOp,
         )
     if retFlag:
         return work
Example 20
    def all_to_all(self, collectiveArgs: collectiveArgsHolder, retFlag=False):
        if collectiveArgs.all2all_qcomm:
            work = all_to_all_internal(collectiveArgs)
        else:
            work = dist.all_to_all_single(
                collectiveArgs.opTensor,
                collectiveArgs.ipTensor,
                None,
                None,
                group=collectiveArgs.group,
                async_op=collectiveArgs.asyncOp,
            )

        if retFlag:
            return work
Example 21
 def backward(ctx, *grad_outputs):
     global myreq
     with record_function("DLRM alltoall_wait_bwd_single"):
         a2a_info = ctx.a2a_info
         grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
         grad_output = torch.cat(grad_outputs)
         grad_input = grad_output.new_empty(
             [a2a_info.batch_size * a2a_info.local_table_num * a2a_info.emb_dim]
         )
         req = dist.all_to_all_single(
             grad_input,
             grad_output,
             a2a_info.batch_split_lengths,
             a2a_info.table_split_lengths,
             async_op=True,
         )
         myreq.req = req
         myreq.tensor = grad_input
         return (grad_output,)
Example 22
    def all_to_all(self, collectiveArgs: collectiveArgsHolder, retFlag=False, pair=False):
        # pair=True mode does not support quantization
        if collectiveArgs.all2all_qcomm and not pair:
            work = all_to_all_internal(collectiveArgs)
        else:
            work = dist.all_to_all_single(
                collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
                collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
                None,
                None,
                group=collectiveArgs.group,
                async_op=collectiveArgs.asyncOp,
            )

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(work)

        if retFlag:
            return work
Example 23
def init_distributed(rank=-1,
                     local_rank=-1,
                     size=-1,
                     use_gpu=False,
                     backend=""):
    global myreq
    global my_rank
    global my_size
    global my_local_rank
    global my_local_size
    global a2a_impl
    global alltoall_supported

    # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2)
    num_mpi_ranks = env2int([
        "PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"
    ])
    if backend == "" and num_mpi_ranks > 1:
        if torch_ccl and env2int(["CCL_WORKER_COUNT"]) > 0:
            backend = "ccl"
        elif use_gpu and dist.is_nccl_available():
            backend = "nccl"
        elif dist.is_mpi_available():
            backend = "mpi"
        else:
            print(
                "WARNING: MPI multi-process launch detected but PyTorch MPI backend not available."
            )
            backend = "gloo"

    if backend != "":
        # guess Rank and size
        if rank == -1:
            rank = env2int([
                "PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK",
                "RANK"
            ], 0)
        if size == -1:
            size = env2int(
                [
                    "PMI_SIZE",
                    "OMPI_COMM_WORLD_SIZE",
                    "MV2_COMM_WORLD_SIZE",
                    "WORLD_SIZE",
                ],
                1,
            )
        if not os.environ.get("RANK", None) and rank != -1:
            os.environ["RANK"] = str(rank)
        if not os.environ.get("WORLD_SIZE", None) and size != -1:
            os.environ["WORLD_SIZE"] = str(size)
        if not os.environ.get("MASTER_PORT", None):
            os.environ["MASTER_PORT"] = "29500"
        if not os.environ.get("MASTER_ADDR", None):
            local_size = env2int(
                [
                    "MPI_LOCALNRANKS",
                    "OMPI_COMM_WORLD_LOCAL_SIZE",
                    "MV2_COMM_WORLD_LOCAL_SIZE",
                ],
                1,
            )
            if local_size != size and backend != "mpi":
                print(
                    "Warning: Looks like distributed multinode run but MASTER_ADDR env not set, using '127.0.0.1' as default"
                )
                print(
                    "If this run hangs, try exporting rank 0's hostname as MASTER_ADDR"
                )
            os.environ["MASTER_ADDR"] = "127.0.0.1"

    if size > 1:
        if local_rank == -1:
            my_local_rank = env2int(
                [
                    "MPI_LOCALRANKID",
                    "OMPI_COMM_WORLD_LOCAL_RANK",
                    "MV2_COMM_WORLD_LOCAL_RANK",
                    "LOCAL_RANK",
                ],
                0,
            )
        else:
            my_local_rank = local_rank
        my_local_size = env2int(
            [
                "MPI_LOCALNRANKS",
                "OMPI_COMM_WORLD_LOCAL_SIZE",
                "MV2_COMM_WORLD_LOCAL_SIZE",
            ],
            1,
        )
        if use_gpu:
            if my_local_size > torch.cuda.device_count():
                print(
                    "Not sufficient GPUs available... local_size = %d, ngpus = %d"
                    % (my_local_size, torch.cuda.device_count()))
                sys.exit(1)
            torch.cuda.set_device(my_local_rank)
        dist.init_process_group(backend, rank=rank, world_size=size)
        my_rank = dist.get_rank()
        my_size = dist.get_world_size()
        if my_rank == 0:
            print("Running on %d ranks using %s backend" % (my_size, backend))
        if hasattr(dist, "all_to_all_single"):
            try:
                t = torch.zeros([4])
                if use_gpu:
                    t = t.cuda()
                dist.all_to_all_single(t, t)
                alltoall_supported = True
            except RuntimeError as err:
                print("failed to enable all_to_all_single primitive: %s" % err)
        if a2a_impl == "alltoall" and not alltoall_supported:
            print(
                "Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not support it, using scatter/gather based alltoall"
                % (a2a_impl, backend))
            a2a_impl = "scatter"
        if a2a_impl != "":
            print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl)
    else:
        my_rank = 0
        my_size = 1
        my_local_rank = 0
        my_local_size = 1
    print_all("world size: %d, current rank: %d, local rank: %d" %
              (my_size, my_rank, my_local_rank))
    myreq = Request()
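The env2int helper used throughout this function is not shown. A plausible minimal version (an assumption, not the original implementation) returns the first environment variable in the list that parses to a non-negative integer, else a default:

import os

def env2int_sketch(env_list, default=-1):
    # Return the first listed environment variable that parses to a value >= 0.
    for e in env_list:
        val = int(os.environ.get(e, -1))
        if val >= 0:
            return val
    return default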
Example 24
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: rank of the current CUDA process.
        local_shard_t: row-wise sharded local weight used for the lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns: final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size, input_t_size[1], device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input, input_t, input_split_sizes=input_split_sizes, group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
Example 25
    def compressed_allreduce(self, buffer_m: torch.Tensor, worker_error,
                             server_error, local_rank):

        # all_start_time = time.time()
        original_shape = buffer_m.size()
        if len(original_shape) > 1:
            buffer_m = torch.flatten(buffer_m)
        original_size = buffer_m.numel()
        worker_error_size = worker_error.numel()
        cupy.cuda.Device(local_rank).use()

        if original_size != worker_error_size:
            empty_tensor = torch.zeros(worker_error_size - original_size,
                                       device=buffer_m.device)
            buffer_m = torch.cat([buffer_m, empty_tensor])

        buffer_m.add_(worker_error)
        worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
        worker_error.set_(
            buffer_m - worker_scale *
            buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

        cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(
                buffer_m.sign_().add_(1).bool()), self.size)
        cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)

        cupy_recvbuf_sign = cupy.zeros(
            [self.size, cupy_sign_list_packed[self.rank].size],
            dtype=cupy_sign_list_packed[0].dtype)
        # cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

        sign_list_packed = [
            self.compression_backend.cupy2torch(cupy_sign_list_packed[idx])
            for idx in range(self.size)
        ]

        # worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
        recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
        #recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
        recvbuf_scale = [
            torch.zeros(1,
                        dtype=worker_scale.dtype,
                        device=torch.device(local_rank))
            for i in range(self.size)
        ]

        # communication phase 1
        # gather_start = time.time()
        # Alltoall for sign
        dist.all_to_all_single(recvbuf_sign,
                               torch.stack(sign_list_packed),
                               group=self.world_group)
        # Allgather for scale
        dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

        # gather_end = time.time()

        # cupy_sign_list_packed, sign_list_packed, cupy_worker_scale, worker_scale = None, None, None, None
        cupy_sign_list_packed = None

        cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign)
        #cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale))

        compensated_server_m = self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
                self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                    torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
        compensated_server_m.add_(server_error)
        server_scale = torch.norm(compensated_server_m) / np.sqrt(
            compensated_server_m.numel())
        server_error.set_(compensated_server_m -
                          server_scale * compensated_server_m.sign().add_(
                              1).bool().float().add_(-0.5).mul_(2.0))

        # cupy_server_scale = self.compression_backend.torch2cupy(server_scale)

        cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
            self.compression_backend.torch2cupy(
                compensated_server_m.sign_().add_(1).bool()), 1)
        compensated_server_m = None

        cupy_recvbuf_sign_server = cupy.zeros(
            [self.size, cupy_server_sign_packed[0].size],
            dtype=cupy_recvbuf_sign.dtype)
        # cupy_recvbuf_sign, recvbuf_sign = None, None
        cupy_recvbuf_sign = None

        server_sign_packed = [
            self.compression_backend.cupy2torch(cupy_server_sign_packed[0])
        ]
        recvbuf_sign_server = [
            self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx])
            for idx in range(self.size)
        ]

        # server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
        cupy_recvbuf_scale_server = cupy.zeros([self.size, 1],
                                               dtype=cupy_worker_scale.dtype)
        # cupy_recvbuf_scale, recvbuf_scale = None, None

        recvbuf_scale_server = [
            self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx])
            for idx in range(self.size)
        ]

        # Communication Phase 2
        dist.all_gather(recvbuf_sign_server,
                        server_sign_packed[0],
                        group=self.world_group)
        dist.all_gather(recvbuf_scale_server,
                        server_scale,
                        group=self.world_group)

        cupy_server_sign_packed = None

        # need to convert from a tensor list to a single tensor
        # dist.all_gather only provides a tensor list as the recv/output buffer
        recvbuf_sign_server = torch.stack(recvbuf_sign_server)

        cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(
            recvbuf_sign_server)

        buffer_m.data.copy_(
            self.compression_backend.cupy2torch(
                (cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                    self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                        self.compression_backend.cupy2torch(
                            cupy_recvbuf_scale_server)).flatten().data)
        if original_size != worker_error_size:
            buffer_m = buffer_m[0:original_size]
        if len(original_shape) > 1:
            buffer_m = buffer_m.reshape(original_shape)

        return buffer_m
Example 26
    t1 = t1.cuda()
    t2 = t2.cuda()

if args.op == "p2p":
    if rank == 0:
        dist.send(t, 1)
    else:
        dist.recv(t, 0)
elif args.op == "broadcast":
    dist.broadcast(t, 0)
elif args.op == "allreduce":
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
elif args.op == "reduce":
    dist.reduce(t, 0, op=dist.ReduceOp.SUM)
elif args.op == "alltoall":
    dist.all_to_all_single(t2, t)
elif args.op == "alltoallv":
    out_split = [1] * size
    in_split = [1] * size
    dist.all_to_all_single(t2, t, out_split, in_split)
elif args.op == "allgather":
    dist.all_gather([t1, t2], t)

else:
    print("Incorrect operation")
    sys.exit(1)

#dist.barrier()
print('rank ', rank, ':', t, ":", t1, ":", t2)
dist.destroy_process_group()
Example 27
def shuffle_data(inputs):
    input = torch.cat(inputs)
    output = input.new_empty(input.size())
    # Synchronous call (async_op defaults to False), so no request handle is needed.
    dist.all_to_all_single(output, input)
    output = output.reshape(my_size, -1)
    return output
Example 28
def _handle_row_wise_sharding(
    input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embedding. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding.)

    Args:
        input: list of ID used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for the lookup.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
        rank: rank of the current CUDA process.
        pg: process group.

    Returns: final result of lookup.
    """
    # flatten the ids across all input and sort
    input_size = input.size()
    input_1d = torch.reshape(input, (-1,)).contiguous()
    input_sorted, indices_1d = torch.sort(input_1d)
    rearrange_indices_1d = torch.argsort(indices_1d)
    input_sorted.contiguous()

    (
        input_sorted,
        input_split_sizes,
        sharded_dim_size_max,
        _,
        rearrange_indices_1d_second_order,
        padding_idx,
    ) = _handle_row_wise_lookup_distribute(
        input_sorted, input, world_size, weight, rank, padding_idx
    )

    # Get the input split size to be sent from each rank to the current rank.
    # We can then infer the output split size.
    output_split_sizes = _communicate_size_to_each_rank(
        input_split_sizes, world_size, input, pg
    )

    # Input sent from each rank to the current rank may have different sizes.
    gathered_input = torch.empty(
        sum(output_split_sizes), dtype=torch.int64, device=input.device
    )

    # Perform the modular operation of the 1D tensor to be sent to each rank.
    input_sorted = torch.remainder(input_sorted, sharded_dim_size_max)

    # Perform alltoall
    dist.all_to_all_single(
        gathered_input,
        input_sorted,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )

    # If input is None, passing in max_norm causes
    # errors in CUDA.
    if max_norm is not None and gathered_input.size(0) == 0:
        max_norm = None

    # Perform local embedding look up.
    gathered_input_embeddings = torch.nn.functional.embedding(
        gathered_input,
        local_shard,
        padding_idx=padding_idx,
        max_norm=max_norm,
        norm_type=norm_type,
    )

    # Gather all lookup result appropriately by performing alltoall again
    gathered_output = torch.empty(
        input_sorted.size(0), weight.size(1), device=input.device
    )
    dist.all_to_all_single(
        gathered_output,
        gathered_input_embeddings,
        input_split_sizes=output_split_sizes,
        output_split_sizes=input_split_sizes,
        group=pg,
    )

    # Rearrange the results to its original shape.
    if rearrange_indices_1d_second_order is not None:
        gathered_output = gathered_output[rearrange_indices_1d_second_order]
    gathered_output = gathered_output[rearrange_indices_1d]

    # Return the appropriate local result.
    return torch.reshape(gathered_output, (*input_size, weight.size(1)))
Example 29
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard,
                              pg):
    # flatten the ids across all input and sort
    input_size = input.size()
    input_1d = torch.reshape(input, (-1, )).contiguous()
    input_sorted, indices_1d = torch.sort(input_1d)
    rearrange_indices_1d = torch.argsort(indices_1d)
    input_sorted.contiguous()

    # Decide which rank the input goes to by check the sharding range.
    split_size = get_split_size(weight.size(0), world_size)
    rearrange_rows = False
    input_split_sizes: List[int] = [0] * world_size
    input_split_start_indices: List[int] = [0] * world_size
    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
    # are not possible. The expected split size will be [4, 4, 4, 1].
    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size,
                                                idx)
        start_row_idx = idx * sharded_dim_size_max
        end_row_idx = start_row_idx + sharded_dim_size
        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
        input_split_start_indices[placement.rank()] = int(start_idx)
        if placement.rank() != idx:
            rearrange_rows = True

    rearrange_indices_1d_second_order = None
    if rearrange_rows:
        # Need to re-arrange the 1D tensor to be sent via all2all.
        indices: List[List[int]] = [[0]] * world_size
        for placement in weight._sharding_spec.placements:
            split_length = input_split_sizes[placement.rank()]
            offset_idx = input_split_start_indices[placement.rank()]
            indices[placement.rank()] = list(
                range(offset_idx, offset_idx + split_length))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_sorted = input_sorted.index_select(
            0, torch.tensor(indices_flatten, device=input.device))
        rearrange_indices_1d_second_order = torch.argsort(
            torch.Tensor(indices_flatten))

    # Get the input split size to be sent from each rank to the current rank.
    # We can then infer the output split size.
    input_split_sizes_tensor = (
        torch.Tensor(input_split_sizes).type("torch.IntTensor").cuda(rank))
    output_split_sizes_tensor = torch.empty(world_size,
                                            dtype=torch.int32,
                                            device=input.device)
    dist.all_to_all_single(
        output_split_sizes_tensor,
        input_split_sizes_tensor,
        group=pg,
    )
    output_split_sizes = output_split_sizes_tensor.tolist()

    # Input sent from each rank to the current rank may have different sizes.
    gathered_input = torch.empty(sum(output_split_sizes),
                                 dtype=torch.int64,
                                 device=input.device)

    # Perform the modular operation of the 1D tensor to be sent to each rank.
    input_sorted = torch.remainder(input_sorted, sharded_dim_size_max)

    # Perform alltoall
    dist.all_to_all_single(
        gathered_input,
        input_sorted,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )

    # Perform local embedding look up.
    gathered_input_embeddings = torch.nn.functional.embedding(
        gathered_input, local_shard)

    # Gather all lookup result appropriately by performing alltoall again
    gathered_output = torch.empty(input_sorted.size(0),
                                  weight.size(1),
                                  device=input.device)
    dist.all_to_all_single(
        gathered_output,
        gathered_input_embeddings,
        input_split_sizes=output_split_sizes,
        output_split_sizes=input_split_sizes,
        group=pg,
    )

    # Rearrange the results to its original shape.
    if rearrange_indices_1d_second_order is not None:
        gathered_output = gathered_output[rearrange_indices_1d_second_order]
    gathered_output = gathered_output[rearrange_indices_1d]

    # Return the appropriate local result.
    return torch.reshape(gathered_output, (*input_size, weight.size(1)))
Example 30
          ('size', 'min, us', 'avg, us', 'max, us'))

if args.backend != 'mpi':
    dist.init_process_group(args.backend, rank=comm_rank, world_size=comm_size)
else:
    dist.init_process_group(args.backend)

size = args.min_size
while size <= args.max_size:
    bufsize = size * comm_size
    send_tensor = get_tensor(bufsize, args.device, comm_rank)
    recv_tensor = get_tensor(bufsize, args.device, 0)
    time = 0
    for i in range(args.iter + args.skip):
        start = perf_counter()
        req = dist.all_to_all_single(recv_tensor, send_tensor, async_op=True)
        #req = dist.all_reduce(send_tensor, op=dist.ReduceOp.SUM, async_op=True)
        req.wait()
        torch.cuda.synchronize(args.device)
        finish = perf_counter()
        dist.barrier()
        if i >= args.skip:  # count only the measured iterations after the warmups
            time += finish - start
    time = [time / args.iter]
    max_time = torch.tensor([time], device=args.device)
    min_time = torch.tensor([time], device=args.device)
    avg_time = torch.tensor([time], device=args.device)
    dist.all_reduce(max_time, op=dist.ReduceOp.MAX)
    dist.all_reduce(min_time, op=dist.ReduceOp.MIN)
    dist.all_reduce(avg_time, op=dist.ReduceOp.SUM)
    if comm_rank == 0: