def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
    ctx.group = group
    input = input.contiguous()
    output = torch.empty_like(input)
    dist.all_to_all_single(output, input, group=group)
    return output
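# Hedged sketch (assumed, not part of the snippet above): a complete autograd wrapper in
# the same spirit, where backward simply performs the reverse all-to-all on the gradient.
# The class name _AllToAllSketch and its wiring are illustrative assumptions.
import torch
import torch.distributed as dist
from typing import Any, Tuple
from torch import Tensor


class _AllToAllSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:
        ctx.group = group
        input = input.contiguous()
        output = torch.empty_like(input)
        dist.all_to_all_single(output, input, group=group)
        return output

    @staticmethod
    def backward(ctx: Any, grad_output: Tensor) -> Tuple[None, Tensor]:
        # The gradient of an equal-split all-to-all is the all-to-all of the gradient.
        return (None, _AllToAllSketch.apply(ctx.group, grad_output))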
def timed_all_to_all(input, output, args):
    if args.dist == 'torch':
        import torch.distributed as dist
    elif args.dist == 'deepspeed':
        import deepspeed.comm as dist

    sync_all()
    # Warmups, establish connections, etc.
    for i in range(args.warmups):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()

    # time the actual comm op trials times and average it
    pre = time.perf_counter()
    for i in range(args.trials):
        dist.all_to_all_single(output, input, async_op=args.async_op)
    sync_all()
    duration = time.perf_counter() - pre

    # maintain and clean performance data
    avg_duration = duration / args.trials
    size = input.element_size() * input.nelement()
    n = dist.get_world_size()
    tput, busbw = get_bw('all_to_all', size, avg_duration, args)
    tput_str, busbw_str, duration_str = get_metric_strings(args, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not args.raw:
        size = convert_size(size)

    print_rank_0(
        f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")
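# Hedged sketch of what a bandwidth helper like the get_bw() used above might compute
# (an assumption, not the benchmark's actual implementation): for all-to-all, each rank
# keeps 1/n of its data locally, so bus bandwidth usually applies the NCCL-style
# (n - 1) / n correction on top of the raw algorithmic bandwidth.
def get_bw_sketch(comm_op, size_bytes, duration_s, world_size):
    algbw = size_bytes / duration_s  # bytes per second moved by one rank
    if comm_op == 'all_to_all':
        busbw = algbw * (world_size - 1) / world_size
    else:
        busbw = algbw
    # Report in GB/s.
    return algbw / 1e9, busbw / 1e9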
def _communicate_size_to_each_rank(input_size_list,
                                   output_size,
                                   input,
                                   pg,
                                   tensor_type=torch.int):
    """
    In the circumstance of row-wise sharding of weight, we need to first
    communicate the input length to each rank because each rank gets a
    different one.

    Args:
        input_size_list: list of sizes to be sent to each rank.
        output_size: length of the output tensor.
        input: tensor to be applied op on.
        pg: process group.
        tensor_type: dtype of tensor.

    Return: A list of communication results (int).
    """
    input_size_list_tensor = torch.tensor(input_size_list,
                                          dtype=tensor_type,
                                          device=input.device)
    output_size_list_tensor = torch.empty(output_size,
                                          dtype=tensor_type,
                                          device=input.device)
    dist.all_to_all_single(
        output_size_list_tensor,
        input_size_list_tensor,
        group=pg,
    )
    return output_size_list_tensor.tolist()
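# Hedged usage sketch (assumed names, not from the source above): once per-rank sizes
# have been exchanged as in _communicate_size_to_each_rank, they can drive a
# variable-length all_to_all_single. Assumes an already-initialized process group.
import torch
import torch.distributed as dist


def variable_all_to_all_sketch(local_data, input_split_sizes, pg):
    world_size = dist.get_world_size(group=pg)
    # Learn how much data each peer will send to this rank (one count per peer).
    sizes_in = torch.tensor(input_split_sizes, dtype=torch.int64, device=local_data.device)
    sizes_out = torch.empty(world_size, dtype=torch.int64, device=local_data.device)
    dist.all_to_all_single(sizes_out, sizes_in, group=pg)
    output_split_sizes = sizes_out.tolist()

    # Exchange the actual payload with per-rank split sizes.
    recv = torch.empty(sum(output_split_sizes),
                       dtype=local_data.dtype,
                       device=local_data.device)
    dist.all_to_all_single(recv,
                           local_data,
                           output_split_sizes=output_split_sizes,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    return recv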
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def _alltoall(self, dispatched_attention):
    if dist.get_world_size(group=self.ep_group) > 1:
        dispatched_input = torch.empty_like(dispatched_attention)
        dist.all_to_all_single(dispatched_input,
                               dispatched_attention,
                               group=self.ep_group)
        return dispatched_input
    else:
        return dispatched_attention
def _result_distribute_with_col_rearrange(results, input, sharding_dim_size, world_size,
                                          weight, pg):
    """
    For col-wise sharding of weight, we need to distribute
    results to each rank. We do them in this function.
    Note that, if the index in the Sharding Spec is not equal to
    the rank number, we need to do the rearrangement based on the
    order given by the Sharding Spec (placement).

    Args:
        results: results from ops applied to inputs from all ranks.
            We need to distribute them back to their original ranks.
        input: tensor to be applied op to.
        sharding_dim_size: the max size of the column each rank gets.
        world_size: number of ranks.
        weight: sharded weight tensor.
        pg: process group.

    Return: column rearranged result.
    """
    # Process results and outputs for all2all.
    dims = list(results[0].size())
    dims[0] = sharding_dim_size
    output = torch.empty(*dims, device=input.device)
    combined_results = torch.cat(results)

    # Compute output splits
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [0] * world_size
    for idx, placement in enumerate(weight._sharding_spec.placements):
        output_split_sizes[placement.rank()] = get_chunked_dim_size(
            sharding_dim_size, split_size, idx)

    # distribute the outputs using all2all.
    dist.all_to_all_single(output,
                           combined_results,
                           output_split_sizes=output_split_sizes,
                           group=pg)

    # Check if we need to rearrange columns appropriately for output.
    rearrange_columns = any([
        idx != placement.rank()
        for idx, placement in enumerate(weight._sharding_spec.placements)
    ])
    if not rearrange_columns:
        return output

    indices = []
    for placement in weight._sharding_spec.placements:
        dim_size = output_split_sizes[placement.rank()]
        start = sum([
            split_size if i < placement.rank() else 0
            for i, split_size in enumerate(output_split_sizes)
        ])
        indices += list(range(start, start + dim_size))

    return output.index_select(0, torch.tensor(indices, device=output.device))
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[int] = []
        for placement in weight._sharding_spec.placements:
            sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size,
                                                    placement.rank())
            input_idx = placement.rank() * split_size
            indices += range(input_idx, input_idx + sharded_dim_size)
        input_t = input_t.index_select(
            0, torch.tensor(indices, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def wrapper(*args, **kwargs):
    group = kwargs.get('group', None)
    async_op = kwargs.get('async_op', False)
    if (async_op is True):
        raise RuntimeError('The async_op=True mode is not supported yet.')
    if (func == dist.all_gather):
        tensors = args[0]
        input_tensors = _quantize_tensor(args[1], qtype)
        out_tensors = _quantize_tensor_list(tensors, qtype)
        dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
        for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
            tensors[i] = t
    elif (func == dist.all_to_all):
        tensors = args[0]
        input_tensors = _quantize_tensor_list(args[1], qtype)
        out_tensors = _quantize_tensor_list(tensors, qtype)
        dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
        for i, t in enumerate(
                _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
            tensors[i] = t
    elif (func == dist.all_to_all_single):
        tensors = args[0]
        out_splits = kwargs.get('out_splits', None)
        in_splits = kwargs.get('in_splits', None)
        # Quantizing the input/output tensor
        input_tensors = _quantize_tensor(args[1], qtype)
        out_tensors = _quantize_tensor(tensors, qtype)
        dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
        for i, t in enumerate(
                _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)):
            tensors[i] = t
    else:
        raise RuntimeError(f"The collective op {func} is not supported yet")
def reshard_flatten_tensor(
    input_tensor: ShardedTensor,
    output_spec: ShardingSpec,
    world_size: int,
    my_rank: int,
    device: torch.device,
    process_group: Optional[dist.ProcessGroup],
) -> torch.Tensor:
    """
    Reshard a sharded flattened tensor. This is used by FSDP to do sharded
    state_dict, but the functionality is not supported by ShardedTensor.
    This API is designed to be used for FSDP; therefore this API supports only
    1-D ShardedTensor (hence the naming, reshard_flatten_tensor).

    This API uses the ChunkShardingSpec and EnumerableShardingSpec from
    torch.distributed.sharding_spec but ignores the placement field in
    ChunkShardingSpec, as the placement requires the callees understand the
    number of GPUs per node. The API simply uses the semantics of the
    sharding specs.

    Args:
        input_tensor (ShardedTensor): the original ShardedTensor. Must be 1D.
        output_spec (ShardingSpec): the sharding spec for the output tensor.
        world_size (int): total trainer count.
        my_rank (int): the rank for this trainer.

    Returns:
        The local shard for the new ShardedTensor.
    """
    input_spec = input_tensor.sharding_spec()
    size = input_tensor.size()
    if isinstance(size, int):
        raise ValueError("The input tensor has no dimensions.")
    tensor_numel = size.numel()
    input_offsets = _sharding_spec_to_offsets(input_spec, tensor_numel, world_size)
    output_offsets = _sharding_spec_to_offsets(output_spec, tensor_numel, world_size)
    input_split_sizes, output_split_sizes = _offsets_to_split_sizes(
        input_offsets, output_offsets, tensor_numel, world_size, my_rank)
    output_size = sum(output_split_sizes)
    local_shard = torch.empty(output_size, dtype=input_tensor.dtype, device=device)
    dist.all_to_all_single(
        local_shard,
        input_tensor.local_shards()[0].tensor,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=process_group,
    )
    return local_shard
def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
    ctx.group = group
    ctx.input_size = input.size()
    # The split sizes are stored swapped so the backward pass can run the reverse all-to-all.
    ctx.output_split_sizes = input_split_sizes
    ctx.input_split_sizes = output_split_sizes
    dist.all_to_all_single(
        output,
        input,
        output_split_sizes=output_split_sizes,
        input_split_sizes=input_split_sizes,
        group=group,
    )
    return output
def _handle_col_wise_sharding(input, world_size, weight, local_shard_t, bias, pg):
    # allgather the inputs first.
    gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
    dist.all_gather(gathered_inputs, input, group=pg)

    # matmul all the inputs.
    results = []
    for i, inp in enumerate(gathered_inputs):
        results.append(inp.matmul(local_shard_t).t())

    # Process inputs and outputs for all2all.
    sharding_dim_size = weight.size()[0]
    output = torch.empty((sharding_dim_size, input.size(0)), device=input.device)
    combined_results = torch.cat(results)

    # Compute output splits
    split_size = get_split_size(sharding_dim_size, world_size)
    output_split_sizes = [
        get_chunked_dim_size(sharding_dim_size, split_size, placement.rank())
        for placement in weight._sharding_spec.placements
    ]

    # distribute the outputs using all2all.
    dist.all_to_all_single(output,
                           combined_results,
                           output_split_sizes=output_split_sizes,
                           group=pg)

    # Check if we need to rearrange rows appropriately for output.
    rearrange_rows = any([
        idx != placement.rank()
        for idx, placement in enumerate(weight._sharding_spec.placements)
    ])
    if rearrange_rows:
        indices = []
        for placement in weight._sharding_spec.placements:
            dim_size = output_split_sizes[placement.rank()]
            start = sum([
                split_size if i < placement.rank() else 0
                for i, split_size in enumerate(output_split_sizes)
            ])
            indices += list(range(start, start + dim_size))
        output = output.index_select(
            0, torch.tensor(indices, device=output.device))

    # add bias and return result.
    return output.t() + bias
def forward(ctx, a2ai, *inputs):
    global myreq
    mb_split_lengths = a2ai.gNS
    if mb_split_lengths:
        mb_split_lengths = [m * a2ai.lS * a2ai.E for m in mb_split_lengths]
    emb_split_lengths = a2ai.gSS
    if emb_split_lengths:
        emb_split_lengths = [a2ai.lN * e * a2ai.E for e in emb_split_lengths]
    input = torch.cat(inputs, dim=1).view([-1])
    output = input.new_empty([a2ai.S * a2ai.lN * a2ai.E])
    req = dist.all_to_all_single(output,
                                 input,
                                 emb_split_lengths,
                                 mb_split_lengths,
                                 async_op=True)

    myreq.req = req
    myreq.tensor = []
    myreq.tensor.append(output)
    myreq.tensor = tuple(myreq.tensor)
    a2ai.mb_split_lengths = mb_split_lengths
    a2ai.emb_split_lengths = emb_split_lengths
    myreq.a2ai = a2ai
    ctx.a2ai = a2ai
    return myreq.tensor
def alltoall_test():
    global comm_rank, comm_size
    dev = torch.device('cuda')
    t_send = torch.zeros(comm_size, device=dev) + comm_rank
    t_recv = torch.zeros(comm_size, device=dev)
    dist.all_to_all_single(t_recv, t_send)
    t_recv = t_recv + 1
    dist.all_reduce(t_recv)
    if torch.all(
            torch.eq(
                t_recv,
                comm_size * torch.arange(start=1, end=comm_size + 1, device=dev))):
        print(f"Rank {comm_rank}: success")
    else:
        print(f"Rank {comm_rank}: failed")
def all_to_allv(self, collectiveArgs, retFlag=False, pair=False):
    # pair=True mode does not support quantization
    if (
        collectiveArgs.all2all_qcomm
        and collectiveArgs.ipTensor.dtype == torch.float32
        and (
            collectiveArgs.opTensor.nelement() >= collectiveArgs.quan_threshold
            or collectiveArgs.ipTensor.nelement() >= collectiveArgs.quan_threshold
        )
        and not pair
    ):
        work = all_to_allv_internal(collectiveArgs)
    else:
        work = dist.all_to_all_single(
            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
            collectiveArgs.opTensor_split if not pair else collectiveArgs.opTensor_split_pair,
            collectiveArgs.ipTensor_split if not pair else collectiveArgs.ipTensor_split_pair,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(work)

    if retFlag:
        return work
def forward(ctx, a2a_info, *inputs):
    global myreq
    with record_function("DLRM alltoall_req_fwd_single"):
        batch_split_lengths = a2a_info.global_batch_partition_slices
        if batch_split_lengths:
            batch_split_lengths = [
                m * a2a_info.emb_dim * a2a_info.local_table_num
                for m in batch_split_lengths
            ]
        table_split_lengths = a2a_info.global_table_wise_parition_slices
        if table_split_lengths:
            table_split_lengths = [
                a2a_info.local_batch_num * e * a2a_info.emb_dim
                for e in table_split_lengths
            ]
        input = torch.cat(inputs, dim=1).view([-1])
        output = input.new_empty([
            a2a_info.global_table_num * a2a_info.local_batch_num * a2a_info.emb_dim
        ])
        req = dist.all_to_all_single(output,
                                     input,
                                     table_split_lengths,
                                     batch_split_lengths,
                                     async_op=True)

        myreq.req = req
        myreq.tensor = []
        myreq.tensor.append(output)
        myreq.tensor = tuple(myreq.tensor)
        a2a_info.batch_split_lengths = batch_split_lengths
        a2a_info.table_split_lengths = table_split_lengths
        myreq.a2a_info = a2a_info
        ctx.a2a_info = a2a_info
        return myreq.tensor
def all_to_all(tensor, group):
    """Perform an all-to-all operation on a 1D Tensor."""
    assert tensor.dim() == 1
    split_count = get_world_size(group=group)
    assert tensor.numel() % split_count == 0
    if use_xla():
        assert isinstance(group, tuple) and group[0] == "tpu"
        return xm.all_to_all(
            tensor,
            split_dimension=0,
            concat_dimension=0,
            split_count=split_count,
            groups=group[1],
        )
    else:
        output = torch.zeros_like(tensor)
        dist.all_to_all_single(output, tensor, group=group)
        return output
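# Hedged standalone demo (assumed, not from the wrapper above): equal-split
# all_to_all_single with one element per peer, launched with torchrun. The nccl
# backend is assumed here; as the support probe in init_distributed further below
# suggests, availability varies by backend, so swap the backend if needed.
import torch
import torch.distributed as dist


def demo():
    dist.init_process_group("nccl")  # assumes torchrun set RANK/WORLD_SIZE/MASTER_*
    rank, world = dist.get_rank(), dist.get_world_size()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    # Slot j of the input goes to rank j; rank r sends the value r to every peer.
    send = torch.full((world,), float(rank), device="cuda")
    recv = torch.zeros(world, device="cuda")
    dist.all_to_all_single(recv, send)
    print(f"rank {rank}: {recv.tolist()}")  # every rank sees [0.0, 1.0, ..., world - 1]
    dist.destroy_process_group()


if __name__ == "__main__":
    demo()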
def all_to_all(self, collectiveArgs, retFlag=False):
    retObj = dist.all_to_all_single(
        collectiveArgs.opTensor,
        collectiveArgs.ipTensor,
        group=collectiveArgs.group,
        async_op=collectiveArgs.asyncOp,
    )  # synchronicity is maintained in runColl
    if retFlag:
        return retObj
    else:
        return
def backward(ctx, *grad_outputs):
    global myreq
    #print("All2All_Wait:backward")
    a2ai = ctx.a2ai
    grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
    grad_output = torch.cat(grad_outputs)
    grad_input = grad_output.new_empty([a2ai.N * a2ai.lS * a2ai.E])
    req = dist.all_to_all_single(grad_input,
                                 grad_output,
                                 a2ai.mb_split_lengths,
                                 a2ai.emb_split_lengths,
                                 async_op=True)

    myreq.req = req
    myreq.tensor = grad_input
    return (grad_output,)
def all_to_allv(self, collectiveArgs, retFlag=False):
    if collectiveArgs.all2all_qcomm:
        work = all_to_allv_internal(collectiveArgs)
    else:
        work = dist.all_to_all_single(
            collectiveArgs.opTensor,
            collectiveArgs.ipTensor,
            collectiveArgs.opTensor_split,
            collectiveArgs.ipTensor_split,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )
    if retFlag:
        return work
def all_to_all(self, collectiveArgs: collectiveArgsHolder, retFlag=False):
    if collectiveArgs.all2all_qcomm:
        work = all_to_all_internal(collectiveArgs)
    else:
        work = dist.all_to_all_single(
            collectiveArgs.opTensor,
            collectiveArgs.ipTensor,
            None,
            None,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )
    if retFlag:
        return work
def backward(ctx, *grad_outputs):
    global myreq
    with record_function("DLRM alltoall_wait_bwd_single"):
        a2a_info = ctx.a2a_info
        grad_outputs = [gout.contiguous().view([-1]) for gout in grad_outputs]
        grad_output = torch.cat(grad_outputs)
        grad_input = grad_output.new_empty(
            [a2a_info.batch_size * a2a_info.local_table_num * a2a_info.emb_dim]
        )
        req = dist.all_to_all_single(
            grad_input,
            grad_output,
            a2a_info.batch_split_lengths,
            a2a_info.table_split_lengths,
            async_op=True,
        )
        myreq.req = req
        myreq.tensor = grad_input
        return (grad_output,)
def all_to_all(self, collectiveArgs: collectiveArgsHolder, retFlag=False, pair=False):
    # pair=True mode does not support quantization
    if collectiveArgs.all2all_qcomm and not pair:
        work = all_to_all_internal(collectiveArgs)
    else:
        work = dist.all_to_all_single(
            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
            None,
            None,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(work)

    if retFlag:
        return work
def init_distributed(rank=-1, local_rank=-1, size=-1, use_gpu=False, backend=""):
    global myreq
    global my_rank
    global my_size
    global my_local_rank
    global my_local_size
    global a2a_impl
    global alltoall_supported

    # guess MPI ranks from env (works for IMPI, OMPI and MVAPICH2)
    num_mpi_ranks = env2int(
        ["PMI_SIZE", "OMPI_COMM_WORLD_SIZE", "MV2_COMM_WORLD_SIZE", "WORLD_SIZE"])
    if backend == "" and num_mpi_ranks > 1:
        if torch_ccl and env2int(["CCL_WORKER_COUNT"]) > 0:
            backend = "ccl"
        elif use_gpu and dist.is_nccl_available():
            backend = "nccl"
        elif dist.is_mpi_available():
            backend = "mpi"
        else:
            print(
                "WARNING: MPI multi-process launch detected but PyTorch MPI backend not available."
            )
            backend = "gloo"

    if backend != "":
        # guess Rank and size
        if rank == -1:
            rank = env2int(
                ["PMI_RANK", "OMPI_COMM_WORLD_RANK", "MV2_COMM_WORLD_RANK", "RANK"], 0)
        if size == -1:
            size = env2int(
                [
                    "PMI_SIZE",
                    "OMPI_COMM_WORLD_SIZE",
                    "MV2_COMM_WORLD_SIZE",
                    "WORLD_SIZE",
                ],
                1,
            )
        if not os.environ.get("RANK", None) and rank != -1:
            os.environ["RANK"] = str(rank)
        if not os.environ.get("WORLD_SIZE", None) and size != -1:
            os.environ["WORLD_SIZE"] = str(size)
        if not os.environ.get("MASTER_PORT", None):
            os.environ["MASTER_PORT"] = "29500"
        if not os.environ.get("MASTER_ADDR", None):
            local_size = env2int(
                [
                    "MPI_LOCALNRANKS",
                    "OMPI_COMM_WORLD_LOCAL_SIZE",
                    "MV2_COMM_WORLD_LOCAL_SIZE",
                ],
                1,
            )
            if local_size != size and backend != "mpi":
                print(
                    "Warning: Looks like distributed multinode run but MASTER_ADDR env not set, using '127.0.0.1' as default"
                )
                print(
                    "If this run hangs, try exporting rank 0's hostname as MASTER_ADDR"
                )
            os.environ["MASTER_ADDR"] = "127.0.0.1"

    if size > 1:
        if local_rank == -1:
            my_local_rank = env2int(
                [
                    "MPI_LOCALRANKID",
                    "OMPI_COMM_WORLD_LOCAL_RANK",
                    "MV2_COMM_WORLD_LOCAL_RANK",
                    "LOCAL_RANK",
                ],
                0,
            )
        else:
            my_local_rank = local_rank
        my_local_size = env2int(
            [
                "MPI_LOCALNRANKS",
                "OMPI_COMM_WORLD_LOCAL_SIZE",
                "MV2_COMM_WORLD_LOCAL_SIZE",
            ],
            1,
        )
        if use_gpu:
            if my_local_size > torch.cuda.device_count():
                print(
                    "Not sufficient GPUs available... local_size = %d, ngpus = %d"
                    % (my_local_size, torch.cuda.device_count()))
                sys.exit(1)
            torch.cuda.set_device(my_local_rank)
        dist.init_process_group(backend, rank=rank, world_size=size)
        my_rank = dist.get_rank()
        my_size = dist.get_world_size()
        if my_rank == 0:
            print("Running on %d ranks using %s backend" % (my_size, backend))
        if hasattr(dist, "all_to_all_single"):
            try:
                t = torch.zeros([4])
                if use_gpu:
                    t = t.cuda()
                dist.all_to_all_single(t, t)
                alltoall_supported = True
            except RuntimeError as err:
                print("fail to enable all_to_all_single primitive: %s" % err)
        if a2a_impl == "alltoall" and alltoall_supported == False:
            print(
                "Requested DLRM_ALLTOALL_IMPL=%s but backend %s does not support it, use scatter/gather based alltoall"
                % (a2a_impl, backend))
            a2a_impl = "scatter"
        if a2a_impl != "":
            print("Using DLRM_ALLTOALL_IMPL=%s" % a2a_impl)
    else:
        my_rank = 0
        my_size = 1
        my_local_rank = 0
        my_local_size = 1
    print_all("world size: %d, current rank: %d, local rank: %d"
              % (my_size, my_rank, my_local_rank))
    myreq = Request()
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard_t, bias, pg):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for Linear. (Detailed explanations of the logic can be found in the
    comment for sharded_linear.)

    Args:
        input: matrix to be multiplied with the sharded weight.
        world_size: number of ranks.
        weight: sharded weight tensor.
        rank: # of cuda process.
        local_shard_t: row-wise sharded local weight used for lookup.
        bias: bias term of linear op.
        pg: process group.

    Returns:
        final result of linear operation.
    """
    # alltoall to gather all the appropriate inputs.
    input_t = input.t().contiguous()
    input_t_size = input_t.size()

    # Compute expected size
    split_size = get_split_size(input_t_size[0], world_size)
    input_split_sizes = [0] * world_size
    rearrange_rows = False

    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(input_t_size[0], split_size, idx)
        input_split_sizes[placement.rank()] = sharded_dim_size
        if placement.rank() != idx:
            rearrange_rows = True

    if rearrange_rows:
        # Need to re-arrange rows of input_t for all2all.
        indices: List[List[int]] = [[0]] * world_size
        # When we do the chunk split, we always ensure the first N - 1 chunks get max out
        # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
        # are not possible. The expected split size will be [4, 4, 4, 1].
        sharded_dim_size_max = max(input_split_sizes)
        for idx, placement in enumerate(weight._sharding_spec.placements):
            split_size = input_split_sizes[placement.rank()]
            offset_start_idx = idx * sharded_dim_size_max
            indices[placement.rank()] = list(
                range(offset_start_idx, offset_start_idx + split_size))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_t = input_t.index_select(
            0, torch.tensor(indices_flatten, device=input_t.device))

    gathered_input = torch.empty(input_split_sizes[rank] * world_size,
                                 input_t_size[1],
                                 device=input_t.device)

    # Perform alltoall
    dist.all_to_all_single(gathered_input,
                           input_t,
                           input_split_sizes=input_split_sizes,
                           group=pg)
    gathered_input = gathered_input.t()

    # Perform local matmuls for all shards
    shard_size = local_shard_t.size()[0]
    results = []
    for r in range(world_size):
        inp = torch.narrow(gathered_input, 1, r * shard_size, shard_size)
        results.append(inp.matmul(local_shard_t))

    # Gather all the results appropriately.
    local_result = torch.empty_like(results[rank])
    dist.reduce_scatter(local_result, results, group=pg)

    # Return the appropriate local result.
    return local_result + bias
def compressed_allreduce(self, buffer_m: torch.tensor, worker_error, server_error,
                         local_rank):

    # all_start_time = time.time()
    original_shape = buffer_m.size()
    if len(original_shape) > 1:
        buffer_m = torch.flatten(buffer_m)
    original_size = buffer_m.numel()
    worker_error_size = worker_error.numel()
    cupy.cuda.Device(local_rank).use()

    if original_size != worker_error_size:
        empty_tensor = torch.zeros(worker_error_size - original_size,
                                   device=buffer_m.device)
        buffer_m = torch.cat([buffer_m, empty_tensor])

    buffer_m.add_(worker_error)
    worker_scale = torch.norm(buffer_m) / np.sqrt(torch.numel(buffer_m))
    worker_error.set_(
        buffer_m - worker_scale *
        buffer_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

    cupy_sign_list_packed = self.compression_backend.compress_by_chunk(
        self.compression_backend.torch2cupy(buffer_m.sign_().add_(1).bool()),
        self.size)
    cupy_worker_scale = self.compression_backend.torch2cupy(worker_scale)

    cupy_recvbuf_sign = cupy.zeros(
        [self.size, cupy_sign_list_packed[self.rank].size],
        dtype=cupy_sign_list_packed[0].dtype)
    # cupy_recvbuf_scale = cupy.zeros([self.size, 1], dtype=cupy_worker_scale.dtype)

    sign_list_packed = [
        self.compression_backend.cupy2torch(cupy_sign_list_packed[idx])
        for idx in range(self.size)
    ]

    # worker_scale = self.compression_backend.cupy2torch(cupy_worker_scale)
    recvbuf_sign = self.compression_backend.cupy2torch(cupy_recvbuf_sign)
    #recvbuf_scale = self.compression_backend.cupy2torch(cupy_recvbuf_scale)
    recvbuf_scale = [
        torch.zeros(1, dtype=worker_scale.dtype, device=torch.device(local_rank))
        for i in range(self.size)
    ]

    # communication phase 1
    # gather_start = time.time()
    # Alltoall for sign
    dist.all_to_all_single(recvbuf_sign,
                           torch.stack(sign_list_packed),
                           group=self.world_group)
    # Allgather for scale
    dist.all_gather(recvbuf_scale, worker_scale, group=self.world_group)

    # gather_end = time.time()

    # cupy_sign_list_packed, sign_list_packed, cupy_worker_scale, worker_scale = None, None, None, None
    cupy_sign_list_packed = None

    cupy_recvbuf_sign = self.compression_backend.torch2cupy(recvbuf_sign)
    #cupy_recvbuf_scale = self.compression_backend.torch2cupy(torch.stack(recvbuf_scale))

    compensated_server_m = self.compression_backend.cupy2torch(
        (cupy.unpackbits(cupy_recvbuf_sign.flatten())).reshape(
            self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                torch.stack(recvbuf_scale).mul_(1 / self.size)).sum(0)
    compensated_server_m.add_(server_error)
    server_scale = torch.norm(compensated_server_m) / np.sqrt(
        compensated_server_m.numel())
    server_error.set_(
        compensated_server_m - server_scale *
        compensated_server_m.sign().add_(1).bool().float().add_(-0.5).mul_(2.0))

    # cupy_server_scale = self.compression_backend.torch2cupy(server_scale)

    cupy_server_sign_packed = self.compression_backend.compress_by_chunk(
        self.compression_backend.torch2cupy(
            compensated_server_m.sign_().add_(1).bool()),
        1)
    compensated_server_m = None

    cupy_recvbuf_sign_server = cupy.zeros(
        [self.size, cupy_server_sign_packed[0].size],
        dtype=cupy_recvbuf_sign.dtype)
    # cupy_recvbuf_sign, recvbuf_sign = None, None
    cupy_recvbuf_sign = None

    server_sign_packed = [
        self.compression_backend.cupy2torch(cupy_server_sign_packed[0])
    ]
    recvbuf_sign_server = [
        self.compression_backend.cupy2torch(cupy_recvbuf_sign_server[idx])
        for idx in range(self.size)
    ]

    # server_scale = self.compression_backend.cupy2torch(cupy_server_scale)
    cupy_recvbuf_scale_server = cupy.zeros([self.size, 1],
                                           dtype=cupy_worker_scale.dtype)
    # cupy_recvbuf_scale, recvbuf_scale = None, None

    recvbuf_scale_server = [
        self.compression_backend.cupy2torch(cupy_recvbuf_scale_server[idx])
        for idx in range(self.size)
    ]

    # Communication Phase 2
    dist.all_gather(recvbuf_sign_server,
                    server_sign_packed[0],
                    group=self.world_group)
    dist.all_gather(recvbuf_scale_server, server_scale, group=self.world_group)

    cupy_server_sign_packed = None

    # need to convert from a tensor list to a single tensor
    # dist.all_gather only provides a tensor list as the recv/output buffer
    recvbuf_sign_server = torch.stack(recvbuf_sign_server)

    cupy_recvbuf_sign_server = self.compression_backend.torch2cupy(
        recvbuf_sign_server)

    buffer_m.data.copy_(
        self.compression_backend.cupy2torch(
            (cupy.unpackbits(cupy_recvbuf_sign_server.flatten())).reshape(
                self.size, -1)).float().add_(-0.5).mul_(2.0).mul_(
                    self.compression_backend.cupy2torch(
                        cupy_recvbuf_scale_server)).flatten().data)

    if original_size != worker_error_size:
        buffer_m = buffer_m[0:original_size]
    if len(original_shape) > 1:
        buffer_m = buffer_m.reshape(original_shape)

    return buffer_m
t1 = t1.cuda()
t2 = t2.cuda()

if args.op == "p2p":
    if rank == 0:
        dist.send(t, 1)
    else:
        dist.recv(t, 0)
elif args.op == "broadcast":
    dist.broadcast(t, 0)
elif args.op == "allreduce":
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
elif args.op == "reduce":
    dist.reduce(t, 0, op=dist.ReduceOp.SUM)
elif args.op == "alltoall":
    dist.all_to_all_single(t2, t)
elif args.op == "alltoallv":
    out_split = [1] * size
    in_split = [1] * size
    dist.all_to_all_single(t2, t, out_split, in_split)
elif args.op == "allgather":
    dist.all_gather([t1, t2], t)
else:
    print("Incorrect operation")
    sys.exit(1)

#dist.barrier()
print('rank ', rank, ':', t, ":", t1, ":", t2)

dist.destroy_process_group()
def shuffle_data(inputs):
    input = torch.cat(inputs)
    output = input.new_empty(input.size())
    req = dist.all_to_all_single(output, input)
    output = output.reshape(my_size, -1)
    return output
def _handle_row_wise_sharding(
    input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg
):
    """
    Entry-point function to handle the logic of row-wise sharding of weight
    for embedding. (Detailed explanations of the logic can be found in
    the comment for sharded_embedding.)

    Args:
        input: list of ID used for lookup and aggregation.
        world_size: number of ranks.
        weight: sharded weight tensor.
        local_shard: row-wise sharded local weight used for lookup.
        max_norm: If given, each embedding vector with norm larger
            than max_norm is renormalized to have norm max_norm.
            Note: this will modify weight in-place.
        norm_type: The p in the p-norm to compute for the max_norm option.
        padding_idx: If specified, the entries at padding_idx do
            not contribute to the gradient; therefore, the embedding
            vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”.
        rank: # of cuda process.
        pg: process group.

    Returns:
        final result of lookup.
    """
    # flatten the ids across all input and sort
    input_size = input.size()
    input_1d = torch.reshape(input, (-1,)).contiguous()
    input_sorted, indices_1d = torch.sort(input_1d)
    rearrange_indices_1d = torch.argsort(indices_1d)
    input_sorted.contiguous()

    (
        input_sorted,
        input_split_sizes,
        sharded_dim_size_max,
        _,
        rearrange_indices_1d_second_order,
        padding_idx,
    ) = _handle_row_wise_lookup_distribute(
        input_sorted, input, world_size, weight, rank, padding_idx
    )

    # Get the input split size to be sent from each rank to the current rank.
    # We can then infer the output split size.
    output_split_sizes = _communicate_size_to_each_rank(
        input_split_sizes, world_size, input, pg
    )

    # Input sent from each rank to the current rank may have different sizes.
    gathered_input = torch.empty(
        sum(output_split_sizes), dtype=torch.int64, device=input.device
    )

    # Perform the modular operation of the 1D tensor to be sent to each rank.
    input_sorted = torch.remainder(input_sorted, sharded_dim_size_max)

    # Perform alltoall
    dist.all_to_all_single(
        gathered_input,
        input_sorted,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )

    # If input is None, passing in max_norm causes
    # errors in CUDA.
    if max_norm is not None and gathered_input.size(0) == 0:
        max_norm = None

    # Perform local embedding look up.
    gathered_input_embeddings = torch.nn.functional.embedding(
        gathered_input,
        local_shard,
        padding_idx=padding_idx,
        max_norm=max_norm,
        norm_type=norm_type,
    )

    # Gather all lookup results appropriately by performing alltoall again
    gathered_output = torch.empty(
        input_sorted.size(0), weight.size(1), device=input.device
    )
    dist.all_to_all_single(
        gathered_output,
        gathered_input_embeddings,
        input_split_sizes=output_split_sizes,
        output_split_sizes=input_split_sizes,
        group=pg,
    )

    # Rearrange the results to its original shape.
    if rearrange_indices_1d_second_order is not None:
        gathered_output = gathered_output[rearrange_indices_1d_second_order]
    gathered_output = gathered_output[rearrange_indices_1d]

    # Return the appropriate local result.
    return torch.reshape(gathered_output, (*input_size, weight.size(1)))
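# Hedged sketch (an assumed helper, not part of the source above) isolating the pattern
# used twice in _handle_row_wise_sharding: route ids to the ranks that own the rows,
# look them up locally, then send the rows back with the two split-size lists swapped.
# Assumes ids are already sorted/bucketed per destination rank and shifted to local range.
import torch
import torch.distributed as dist


def _route_lookup_and_return(ids_sorted, input_split_sizes, output_split_sizes,
                             local_shard, pg):
    # Pass 1: ids travel to the ranks that own the corresponding embedding rows.
    recv_ids = torch.empty(sum(output_split_sizes),
                           dtype=torch.int64,
                           device=local_shard.device)
    dist.all_to_all_single(recv_ids,
                           ids_sorted,
                           output_split_sizes=output_split_sizes,
                           input_split_sizes=input_split_sizes,
                           group=pg)

    # Local lookup on this rank's shard.
    rows = torch.nn.functional.embedding(recv_ids, local_shard)

    # Pass 2: looked-up rows travel back; the split sizes are simply swapped.
    out = torch.empty(ids_sorted.size(0),
                      local_shard.size(1),
                      device=local_shard.device)
    dist.all_to_all_single(out,
                           rows,
                           output_split_sizes=input_split_sizes,
                           input_split_sizes=output_split_sizes,
                           group=pg)
    return out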
def _handle_row_wise_sharding(input, world_size, weight, rank, local_shard, pg):
    # flatten the ids across all input and sort
    input_size = input.size()
    input_1d = torch.reshape(input, (-1,)).contiguous()
    input_sorted, indices_1d = torch.sort(input_1d)
    rearrange_indices_1d = torch.argsort(indices_1d)
    input_sorted.contiguous()

    # Decide which rank the input goes to by checking the sharding range.
    split_size = get_split_size(weight.size(0), world_size)
    rearrange_rows = False
    input_split_sizes: List[int] = [0] * world_size
    input_split_start_indices: List[int] = [0] * world_size
    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
    # are not possible. The expected split size will be [4, 4, 4, 1].
    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
    for idx, placement in enumerate(weight._sharding_spec.placements):
        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
        start_row_idx = idx * sharded_dim_size_max
        end_row_idx = start_row_idx + sharded_dim_size
        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
        input_split_start_indices[placement.rank()] = int(start_idx)
        if placement.rank() != idx:
            rearrange_rows = True

    rearrange_indices_1d_second_order = None
    if rearrange_rows:
        # Need to re-arrange the 1D tensor to be sent via all2all.
        indices: List[List[int]] = [[0]] * world_size
        for placement in weight._sharding_spec.placements:
            split_length = input_split_sizes[placement.rank()]
            offset_idx = input_split_start_indices[placement.rank()]
            indices[placement.rank()] = list(
                range(offset_idx, offset_idx + split_length))
        indices_flatten = list(idx for indice in indices for idx in indice)

        input_sorted = input_sorted.index_select(
            0, torch.tensor(indices_flatten, device=input.device))
        rearrange_indices_1d_second_order = torch.argsort(
            torch.Tensor(indices_flatten))

    # Get the input split size to be sent from each rank to the current rank.
    # We can then infer the output split size.
    input_split_sizes_tensor = (
        torch.Tensor(input_split_sizes).type("torch.IntTensor").cuda(rank))
    output_split_sizes_tensor = torch.empty(world_size,
                                            dtype=torch.int32,
                                            device=input.device)
    dist.all_to_all_single(
        output_split_sizes_tensor,
        input_split_sizes_tensor,
        group=pg,
    )
    output_split_sizes = output_split_sizes_tensor.tolist()

    # Input sent from each rank to the current rank may have different sizes.
    gathered_input = torch.empty(sum(output_split_sizes),
                                 dtype=torch.int64,
                                 device=input.device)

    # Perform the modular operation of the 1D tensor to be sent to each rank.
    input_sorted = torch.remainder(input_sorted, sharded_dim_size_max)

    # Perform alltoall
    dist.all_to_all_single(
        gathered_input,
        input_sorted,
        input_split_sizes=input_split_sizes,
        output_split_sizes=output_split_sizes,
        group=pg,
    )

    # Perform local embedding look up.
    gathered_input_embeddings = torch.nn.functional.embedding(
        gathered_input, local_shard)

    # Gather all lookup results appropriately by performing alltoall again
    gathered_output = torch.empty(input_sorted.size(0),
                                  weight.size(1),
                                  device=input.device)
    dist.all_to_all_single(
        gathered_output,
        gathered_input_embeddings,
        input_split_sizes=output_split_sizes,
        output_split_sizes=input_split_sizes,
        group=pg,
    )

    # Rearrange the results to its original shape.
    if rearrange_indices_1d_second_order is not None:
        gathered_output = gathered_output[rearrange_indices_1d_second_order]
    gathered_output = gathered_output[rearrange_indices_1d]

    # Return the appropriate local result.
    return torch.reshape(gathered_output, (*input_size, weight.size(1)))
      ('size', 'min, us', 'avg, us', 'max, us'))

if args.backend != 'mpi':
    dist.init_process_group(args.backend, rank=comm_rank, world_size=comm_size)
else:
    dist.init_process_group(args.backend)

size = args.min_size
while size <= args.max_size:
    bufsize = size * comm_size
    send_tensor = get_tensor(bufsize, args.device, comm_rank)
    recv_tensor = get_tensor(bufsize, args.device, 0)
    time = 0
    for i in range(args.iter + args.skip):
        start = perf_counter()
        req = dist.all_to_all_single(recv_tensor, send_tensor, async_op=True)
        #req = dist.all_reduce(send_tensor, op=dist.ReduceOp.SUM, async_op=True)
        req.wait()
        torch.cuda.synchronize(args.device)
        finish = perf_counter()
        dist.barrier()
        if i > args.skip:
            time += finish - start
    time = [time / args.iter]
    max_time = torch.tensor([time], device=args.device)
    min_time = torch.tensor([time], device=args.device)
    avg_time = torch.tensor([time], device=args.device)
    dist.all_reduce(max_time, op=dist.ReduceOp.MAX)
    dist.all_reduce(min_time, op=dist.ReduceOp.MIN)
    dist.all_reduce(avg_time, op=dist.ReduceOp.SUM)
    if comm_rank == 0: