Example #1
def _communicate_list_to_each_rank(input_tensor_list,
                                   output_lists,
                                   input,
                                   pg,
                                   tensor_type=torch.int64):
    """
    When the weight is sharded row-wise, we need to communicate a list
    of input tensors to each rank. Because the input could be a list of
    lists, we first convert each list to a tensor.

    Args:
        input_tensor_list: list of tensors to be sent to each rank.
        output_lists: list of sizes of the tensors to be received from each rank.
        input: tensor the op is applied to; used here for its device.
        pg: process group.
        tensor_type: dtype of the communicated tensors.

    Return: A list of communication results (tensors).
    """
    output_tensor_list = []
    for output_list in output_lists:
        output_tensor_list.append(
            torch.empty(output_list, dtype=tensor_type, device=input.device))
    dist.all_to_all(
        output_tensor_list,
        input_tensor_list,
        group=pg,
    )
    return output_tensor_list
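A minimal usage sketch for the helper above, assuming an already-initialized process group `pg` and made-up sizes (the function _example_usage and every value in it are illustrative only): each rank sends one int64 tensor to every peer and knows in advance how many elements it will receive from each peer.

import torch
import torch.distributed as dist

def _example_usage(pg, input):
    world_size = dist.get_world_size(pg)
    # One tensor addressed to each destination rank (contents are arbitrary here).
    input_tensor_list = [
        torch.arange(3, dtype=torch.int64, device=input.device)
        for _ in range(world_size)
    ]
    # Number of elements expected back from each source rank.
    output_lists = [3 for _ in range(world_size)]
    return _communicate_list_to_each_rank(
        input_tensor_list, output_lists, input, pg)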
Example #2
 def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor:  # type: ignore
     ctx.group = group
     world_size = dist.get_world_size(group)
     input = input.contiguous()
     output = torch.empty_like(input)
     input_chunks = list(input.chunk(world_size))
     output_chunks = list(output.chunk(world_size))
     dist.all_to_all(output_chunks, input_chunks, group=group)
     return output
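A backend-free sketch of why writing into the output chunks above fills output itself: Tensor.chunk returns views, so whatever the collective writes into each chunk lands directly in the backing tensor.

import torch

output = torch.zeros(4)
chunks = list(output.chunk(2))
chunks[0].fill_(1.0)  # stand-in for what all_to_all would write into chunk 0
chunks[1].fill_(2.0)  # stand-in for what all_to_all would write into chunk 1
print(output)         # tensor([1., 1., 2., 2.])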
Example #3
    def wrapper(*args, **kwargs):
        group = kwargs.get('group', None)
        async_op = kwargs.get('async_op', False)
        if (async_op is True):
            raise RuntimeError('The async_op=True mode is not supported yet.')
        if (func == dist.all_gather):
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors,
                            input_tensors,
                            group=group,
                            async_op=async_op)
            for i, t in enumerate(
                    _dequantize_tensor_list(out_tensors,
                                            qtype,
                                            quant_loss=quant_loss)):
                tensors[i] = t

        elif (func == dist.all_to_all):
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors,
                            input_tensors,
                            group=group,
                            async_op=async_op)
            for i, t in enumerate(
                    _dequantize_tensor_list(out_tensors,
                                            qtype,
                                            quant_loss=quant_loss)):
                tensors[i] = t

        elif (func == dist.all_to_all_single):
            tensors = args[0]
            out_splits = kwargs.get('out_splits', None)
            in_splits = kwargs.get('in_splits', None)
            # Quantizing the input/output tensor
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor(tensors, qtype)
            dist.all_to_all_single(out_tensors,
                                   input_tensors,
                                   out_splits,
                                   in_splits,
                                   group=group)
            for i, t in enumerate(
                    _dequantize_tensor(out_tensors,
                                       qtype,
                                       quant_loss=quant_loss)):
                tensors[i] = t
        else:
            raise RuntimeError(
                f"The collective op {func} is not supported yet")
Example #4
 def forward(ctx, group, *tensors):
     ctx.group = group
     out_tensor_list = [
         torch.empty_like(tensors[i]) for i in range(dist.get_world_size(group=group))
     ]
     reqs = [None] * dist.get_world_size(group=group)
     my_rank = dist.get_rank(group=group)
     # Gloo does not support all_to_all, so emulate it with one scatter per
     # source rank (async send/recv operations have issues on this backend).
     if dist.get_backend(group=group) == dist.Backend.GLOO:
         for i in range(dist.get_world_size(group=group)):
             to_send = None
             if i == my_rank:
                 to_send = list(tensors)
             dist.scatter(out_tensor_list[i], to_send, i, group=group)
     else:
         dist.all_to_all(out_tensor_list, list(tensors), group=group)
     return tuple(out_tensor_list)
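A backend-free sketch of the data movement the Gloo branch above emulates: when every source rank i scatters its list of tensors and each destination keeps the entry addressed to it, the result is exactly the all_to_all permutation (plain Python lists stand in for tensors, and _simulate_all_to_all is illustrative only).

def _simulate_all_to_all(per_rank_inputs):
    # per_rank_inputs[i][r] is what rank i sends to rank r;
    # result[r][i] is what rank r receives from rank i.
    world_size = len(per_rank_inputs)
    return [[per_rank_inputs[i][r] for i in range(world_size)]
            for r in range(world_size)]

# _simulate_all_to_all([["a->0", "a->1"], ["b->0", "b->1"]])
# == [["a->0", "b->0"], ["a->1", "b->1"]]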
Example #5
 def forward(ctx, group, out_tensor_list, *tensors):
     ctx.group = group
     ctx.input_tensor_size_list = [
         tensors[i].size() for i in range(dist.get_world_size(group=group))
     ]
     my_rank = dist.get_rank(group=group)
     tensors = tuple(t.contiguous() for t in tensors)
     # Gloo does not support all_to_all, so emulate it with one scatter per
     # source rank (async send/recv operations have issues on this backend).
     if dist.get_backend(group=group) == dist.Backend.GLOO:
         for i in range(dist.get_world_size(group=group)):
             to_send = None
             if i == my_rank:
                 to_send = list(tensors)
             dist.scatter(out_tensor_list[i], to_send, i, group=group)
     else:
         dist.all_to_all(
             out_tensor_list,
             list(tensors),
             group=group,
         )
     return tuple(out_tensor_list)
Example #6
    def all_to_all(
        self, collectiveArgs: collectiveArgsHolder, retFlag=False, pair=False
    ):
        # pair=True mode does not support quantization
        if collectiveArgs.all2all_qcomm and not pair:
            work = all_to_all_internal(collectiveArgs)
        else:
            work = dist.all_to_all(
                collectiveArgs.opTensor if not pair else collectiveArgs.opTensor_pair,
                collectiveArgs.ipTensor if not pair else collectiveArgs.ipTensor_pair,
                group=collectiveArgs.group,
                async_op=collectiveArgs.asyncOp,
            )

        if collectiveArgs.asyncOp:
            collectiveArgs.waitObj.append(work)

        if retFlag:
            return work
Example #7
 def all_to_all_combine(self, combine_weights: Tensor, input: Tensor) -> Tensor:
     expert_output = input.contiguous()
     chunks = list(expert_output.chunk(self.world_size))
     # The input and output lists alias the same chunk views of expert_output,
     # so the exchanged data lands back in expert_output for the einsum below.
     dist.all_to_all(chunks, chunks, group=self.group)
     output = torch.einsum("gsec,egcm->gsm", combine_weights, expert_output)
     return output
Example #8
 def all_to_all_dispatch(self, dispatch_mask: Tensor, input: Tensor) -> Tensor:
     dispatched_input = torch.einsum("gsec,gsm->egcm", dispatch_mask.float(), input)
     dispatched_input = dispatched_input.contiguous()
     chunks = list(dispatched_input.chunk(self.world_size))
     # The input and output lists alias the same chunk views of dispatched_input,
     # so the exchanged data lands back in dispatched_input, which is returned.
     dist.all_to_all(chunks, chunks, group=self.group)
     return dispatched_input
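A small shape check for the two einsums in examples #7 and #8, assuming the usual MoE dimension labels (g = groups, s = tokens per group, e = experts, c = expert capacity, m = model dim); all sizes below are made up and the all_to_all exchange is omitted.

import torch

g, s, e, c, m = 2, 8, 2, 4, 16
dispatch_mask = torch.randint(0, 2, (g, s, e, c)).float()
x = torch.randn(g, s, m)
dispatched = torch.einsum("gsec,gsm->egcm", dispatch_mask, x)         # (e, g, c, m)
combined = torch.einsum("gsec,egcm->gsm", dispatch_mask, dispatched)  # (g, s, m)
print(dispatched.shape, combined.shape)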
Example #9
print_all("rank = %d, size = %d" % (rank, size))

print_all("all_to_all_single with empty tensors")
dist.all_to_all_single(torch.empty([0]), torch.empty([0]))

dist.barrier()

print_all("scatter using alltoall")
if rank == 1:
    x = [torch.ones([2]) * (r + 1) for r in range(size)]
else:
    x = [torch.zeros([0]) for _ in range(size)]
y = [torch.zeros([2]) if r == 1 else torch.zeros([0]) for r in range(size)]
print_all("x = %s" % x)
print_all("y = %s" % y)
dist.all_to_all(y, x)
print_all("y = %s" % y)

dist.barrier()

print_all("gather using alltoall")
x = [
    torch.ones([2]) * (rank + 1) if r == 1 else torch.zeros([0])
    for r in range(size)
]
if rank == 1:
    y = [torch.zeros([2]) for _ in range(size)]
else:
    y = [torch.zeros([0]) for _ in range(size)]
print_all("x = %s" % x)
print_all("y = %s" % y)