def reducescatter_multigpu(output_tensor_list,
                           input_tensor_lists,
                           group_name: str = "default",
                           op=types.ReduceOp.SUM):
    """Reducescatter a list of tensors across all GPUs.

    Args:
        output_tensor_list: the resulting list of tensors, with
            shape: num_gpus * shape(tensor).
        input_tensor_lists: the original tensors, with shape:
            num_gpus * world_size * shape(tensor).
        group_name (str): the name of the collective group.
        op: The reduce operation.

    Returns:
        None.
    """
    if not types.cupy_available():
        raise RuntimeError("Multigpu calls require NCCL and Cupy.")
    _check_tensor_lists_input(input_tensor_lists)
    _check_tensor_list_input(output_tensor_list)
    g = _check_and_get_group(group_name)
    opts = types.ReduceScatterOptions()
    opts.reduceOp = op
    g.reducescatter(output_tensor_list, input_tensor_lists, opts)
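
# A minimal usage sketch for the multi-GPU variant, kept as a comment so the
# module stays importable. The cupy arrays, the ``col`` alias, and the sizes
# below are illustrative assumptions; it presumes an NCCL collective group
# named "default" was already created on every participating process via
# ray.util.collective.init_collective_group.
#
#     import cupy as cp
#     import ray.util.collective as col
#
#     num_gpus = 2      # GPUs owned by this process
#     world_size = 2    # size of the "default" group
#     input_tensor_lists = []   # num_gpus x world_size tensors to reduce
#     output_tensor_list = []   # one reduced shard per local GPU
#     for gpu in range(num_gpus):
#         with cp.cuda.Device(gpu):
#             input_tensor_lists.append(
#                 [cp.ones((4,), dtype=cp.float32) for _ in range(world_size)])
#             output_tensor_list.append(cp.zeros((4,), dtype=cp.float32))
#     col.reducescatter_multigpu(output_tensor_list, input_tensor_lists)
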
def reducescatter(tensor,
                  tensor_list: list,
                  group_name: str = "default",
                  op=types.ReduceOp.SUM):
    """Reducescatter a list of tensors across the group.

    Reduce the list of tensors across each process in the group, then
    scatter the reduced list of tensors -- one tensor for each process.

    Args:
        tensor: the resulting tensor on this process.
        tensor_list (list): The list of tensors to be reduced and scattered.
        group_name (str): the name of the collective group.
        op: The reduce operation.

    Returns:
        None
    """
    _check_single_tensor_input(tensor)
    _check_tensor_list_input(tensor_list)
    g = _check_and_get_group(group_name)
    if len(tensor_list) != g.world_size:
        raise RuntimeError(
            "The length of the tensor list operands to reducescatter "
            "must be equal to world_size.")
    opts = types.ReduceScatterOptions()
    opts.reduceOp = op
    g.reducescatter([tensor], [tensor_list], opts)
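
# A minimal usage sketch, kept as a comment so the module stays importable.
# The cupy arrays and the ``col`` alias are illustrative assumptions; it
# presumes every process has already joined a group named "default" via
# ray.util.collective.init_collective_group with the NCCL backend.
#
#     import cupy as cp
#     import ray.util.collective as col
#
#     world_size = 2
#     # Each process supplies world_size tensors and receives one reduced shard.
#     tensor_list = [cp.ones((4,), dtype=cp.float32) for _ in range(world_size)]
#     result = cp.zeros((4,), dtype=cp.float32)
#     col.reducescatter(result, tensor_list)
#     # With the default SUM op and two processes, every element of result is 2.0.
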