def _communicate_list_to_each_rank(
    input_tensor_list, output_lists, input, pg, tensor_type=torch.int64
):
    """
    In the circumstance of row-wise sharding of weight, we need to communicate
    a list of input tensors to each rank. Because the input could be a list of
    lists, we first convert each list to a tensor.

    Args:
        input_tensor_list: list of tensors to be sent to each rank.
        output_lists: list of sizes to be obtained from each rank.
        input: tensor to be applied op on.
        pg: process group.
        tensor_type: dtype of tensor.

    Return:
        A list of communication results (tensors).
    """
    output_tensor_list = []
    # Pre-allocate one receive buffer per peer, sized to what that peer will send.
    for output_list in output_lists:
        output_tensor_list.append(
            torch.empty(output_list, dtype=tensor_type, device=input.device)
        )
    # Exchange: entry i of input_tensor_list goes to rank i; entry i of
    # output_tensor_list is filled with the tensor received from rank i.
    dist.all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )
    return output_tensor_list
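# Illustrative usage only (not from the source): a minimal sketch of driving the
# helper above on a job where every rank sends a length-2 id tensor to every peer.
# `input` and `pg` are assumed to come from the caller, as in the helper itself.
def _example_exchange(input, pg):
    world_size = dist.get_world_size(pg)
    send_list = [
        torch.full((2,), dist.get_rank(pg), dtype=torch.int64, device=input.device)
        for _ in range(world_size)
    ]
    # Every peer sends us a length-2 tensor, so each receive buffer has size 2.
    recv_sizes = [2] * world_size
    return _communicate_list_to_each_rank(send_list, recv_sizes, input, pg)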
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
    ctx.group = group
    world_size = dist.get_world_size(group)
    input = input.contiguous()
    output = torch.empty_like(input)
    # Split both buffers into world_size equal chunks; chunk i is exchanged with rank i.
    input_chunks = list(input.chunk(world_size))
    output_chunks = list(output.chunk(world_size))
    dist.all_to_all(output_chunks, input_chunks, group=group)
    return output
def wrapper(*args, **kwargs):
    group = kwargs.get('group', None)
    async_op = kwargs.get('async_op', False)
    if async_op is True:
        raise RuntimeError('The async_op=True mode is not supported yet.')
    if func == dist.all_gather:
        tensors = args[0]
        input_tensors = _quantize_tensor(args[1], qtype)
        out_tensors = _quantize_tensor_list(tensors, qtype)
        dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
        for i, t in enumerate(
            _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
        ):
            tensors[i] = t
    elif func == dist.all_to_all:
        tensors = args[0]
        input_tensors = _quantize_tensor_list(args[1], qtype)
        out_tensors = _quantize_tensor_list(tensors, qtype)
        dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
        for i, t in enumerate(
            _dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)
        ):
            tensors[i] = t
    elif func == dist.all_to_all_single:
        tensors = args[0]
        out_splits = kwargs.get('out_splits', None)
        in_splits = kwargs.get('in_splits', None)
        # Quantize the input/output tensors before the collective.
        input_tensors = _quantize_tensor(args[1], qtype)
        out_tensors = _quantize_tensor(tensors, qtype)
        dist.all_to_all_single(out_tensors, input_tensors, out_splits, in_splits, group=group)
        for i, t in enumerate(
            _dequantize_tensor(out_tensors, qtype, quant_loss=quant_loss)
        ):
            tensors[i] = t
    else:
        raise RuntimeError(f"The collective op {func} is not supported yet")
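# Illustrative only (not from the source): the quantize -> collective -> dequantize
# pattern the wrapper applies, written out for a single all_to_all call.
# `_quantize_tensor_list` / `_dequantize_tensor_list` are the helpers referenced
# above; `out_tensors`, `in_tensors`, `pg`, `qtype`, and `quant_loss` are placeholders.
def _quantized_all_to_all_sketch(out_tensors, in_tensors, pg, qtype, quant_loss=None):
    q_in = _quantize_tensor_list(in_tensors, qtype)
    q_out = _quantize_tensor_list(out_tensors, qtype)
    # Run the collective on the low-precision buffers, then write the
    # dequantized results back into the caller-visible output list.
    dist.all_to_all(q_out, q_in, group=pg)
    for i, t in enumerate(_dequantize_tensor_list(q_out, qtype, quant_loss=quant_loss)):
        out_tensors[i] = t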
def forward(ctx, group, *tensors):
    ctx.group = group
    out_tensor_list = [
        torch.empty_like(tensors[i])
        for i in range(dist.get_world_size(group=group))
    ]
    reqs = [None] * dist.get_world_size(group=group)
    my_rank = dist.get_rank(group=group)
    # Implement it by means of scatter/gather; async send/recv operations have issues
    # on Gloo, so each rank takes a turn scattering its input list to all peers.
    if dist.get_backend(group=group) == dist.Backend.GLOO:
        for i in range(dist.get_world_size(group=group)):
            to_send = None
            if i == my_rank:
                to_send = list(tensors)
            dist.scatter(out_tensor_list[i], to_send, i, group=group)
    else:
        dist.all_to_all(out_tensor_list, list(tensors), group=group)
    return tuple(out_tensor_list)
def forward(ctx, group, out_tensor_list, *tensors):
    ctx.group = group
    ctx.input_tensor_size_list = [
        tensors[i].size() for i in range(dist.get_world_size(group=group))
    ]
    my_rank = dist.get_rank(group=group)
    tensors = tuple(t.contiguous() for t in tensors)
    # Implement it by means of scatter/gather; async send/recv operations have issues
    # on Gloo, so each rank takes a turn scattering its input list to all peers.
    if dist.get_backend(group=group) == dist.Backend.GLOO:
        for i in range(dist.get_world_size(group=group)):
            to_send = None
            if i == my_rank:
                to_send = list(tensors)
            dist.scatter(out_tensor_list[i], to_send, i, group=group)
    else:
        dist.all_to_all(
            out_tensor_list,
            list(tensors),
            group=group,
        )
    return tuple(out_tensor_list)
def all_to_all(
    self, collectiveArgs: collectiveArgsHolder, retFlag=False, pair=False
):
    # pair=True mode does not support quantization
    if collectiveArgs.all2all_qcomm and not pair:
        work = all_to_all_internal(collectiveArgs)
    else:
        work = dist.all_to_all(
            collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
            collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
            group=collectiveArgs.group,
            async_op=collectiveArgs.asyncOp,
        )

    if collectiveArgs.asyncOp:
        collectiveArgs.waitObj.append(work)

    if retFlag:
        return work
def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
    expert_output = input.contiguous()
    # The chunks are views into expert_output, so exchanging them in place
    # overwrites expert_output with the data received from each rank.
    chunks = list(expert_output.chunk(self.world_size))
    dist.all_to_all(chunks, chunks, self.group)
    output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
    return output
def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
    dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
    dispatched_input = dispatched_input.contiguous()
    # Equal-sized chunks are views into dispatched_input, so using the same list
    # for input and output performs the exchange in place.
    chunks = list(dispatched_input.chunk(self.world_size))
    dist.all_to_all(chunks, chunks, self.group)
    return dispatched_input
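# Illustrative only (not from the source): when the chunks are equal-sized views of
# one contiguous tensor, the chunked all_to_all above moves the same data as a single
# flat all_to_all_single call on the whole tensor. `group` is a placeholder here.
def _dispatch_with_all_to_all_single(dispatched_input: Tensor, group) -> Tensor:
    dispatched_input = dispatched_input.contiguous()
    output = torch.empty_like(dispatched_input)
    # all_to_all_single splits dim 0 into world_size equal parts by default.
    dist.all_to_all_single(output, dispatched_input, group=group)
    return output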
print_all("rank = %d, size = %d" % (rank, size))

print_all("all_to_all_single with empty tensors")
dist.all_to_all_single(torch.empty([0]), torch.empty([0]))
dist.barrier()

print_all("scatter using alltoall")
if rank == 1:
    x = [torch.ones([2]) * (r + 1) for r in range(size)]
else:
    x = [torch.zeros([0]) for _ in range(size)]
y = [torch.zeros([2]) if r == 1 else torch.zeros([0]) for r in range(size)]
print_all("x = %s" % x)
print_all("y = %s" % y)
dist.all_to_all(y, x)
print_all("y = %s" % y)
dist.barrier()

print_all("gather using alltoall")
x = [
    torch.ones([2]) * (rank + 1) if r == 1 else torch.zeros([0])
    for r in range(size)
]
if rank == 1:
    y = [torch.zeros([2]) for _ in range(size)]
else:
    y = [torch.zeros([0]) for _ in range(size)]
print_all("x = %s" % x)
print_all("y = %s" % y)