Example #1
0
def reducescatter_multigpu(output_tensor_list,
                           input_tensor_lists,
                           group_name: str = "default",
                           op=types.ReduceOp.SUM):
    """Reduce-scatter per-GPU lists of tensors across the collective group.

    This is the multi-GPU counterpart of ``reducescatter``. It takes nested
    input lists (one list per local GPU) and writes the scattered results
    into ``output_tensor_list``.

    Args:
        output_tensor_list: receives the resulting tensors, with
            shape: num_gpus * shape(tensor).
        input_tensor_lists: the tensors to reduce and scatter, with shape:
            num_gpus * world_size * shape(tensor).
        group_name (str): name of the collective group to run on.
        op: the reduce operation to apply (defaults to ``ReduceOp.SUM``).

    Returns:
        None.

    Raises:
        RuntimeError: if CuPy is not available. The input/group helpers
            called below may raise their own errors (not visible here).
    """
    # The multi-GPU path is unusable without CuPy; refuse before touching
    # any of the arguments.
    cupy_ready = types.cupy_available()
    if not cupy_ready:
        raise RuntimeError("Multigpu calls requires NCCL and Cupy.")

    # Validate the nested inputs first, then the flat output list, then
    # resolve the group, so errors surface in the same order as before.
    _check_tensor_lists_input(input_tensor_lists)
    _check_tensor_list_input(output_tensor_list)
    group = _check_and_get_group(group_name)

    options = types.ReduceScatterOptions()
    options.reduceOp = op
    group.reducescatter(output_tensor_list, input_tensor_lists, options)
Example #2
0
def reducescatter(tensor,
                  tensor_list: list,
                  group_name: str = "default",
                  op=types.ReduceOp.SUM):
    """Reducescatter a list of tensors across the group.

    Reduce the list of tensors across each process in the group, then
    scatter the reduced list of tensors -- one tensor for each process.

    Args:
        tensor: the tensor on this process that receives the result.
        tensor_list (list): The list of tensors to be reduced and scattered.
            Its length must equal the group's world size.
        group_name (str): the name of the collective group.
        op: The reduce operation.

    Returns:
        None

    Raises:
        RuntimeError: if ``len(tensor_list)`` differs from the group's
            ``world_size``. The input/group helpers called below may raise
            their own errors (not visible here).
    """
    _check_single_tensor_input(tensor)
    _check_tensor_list_input(tensor_list)
    g = _check_and_get_group(group_name)
    if len(tensor_list) != g.world_size:
        # One input tensor per process is required so that each process
        # receives exactly one reduced tensor. The old message said
        # "must not be equal", which contradicted this check.
        raise RuntimeError(
            "The length of the tensor list operands to reducescatter "
            "must be equal to world_size ({}), but got {}.".format(
                g.world_size, len(tensor_list)))
    opts = types.ReduceScatterOptions()
    opts.reduceOp = op
    # The group's reducescatter takes the nested-list form (as used by the
    # multi-GPU variant), so wrap the single tensor and its list.
    g.reducescatter([tensor], [tensor_list], opts)