コード例 #1
0
ファイル: gloo_collective_group.py プロジェクト: wxj0916/ray
 def collective_fn(input_tensor, output_tensor, context):
     """Run ``pygloo.allreduce`` from ``input_tensor`` into ``output_tensor``.

     The element count and the gloo dtype are both taken from
     ``input_tensor``. The reduce op is derived from
     ``allreduce_options.reduceOp``, a name captured from the enclosing
     scope (not visible in this snippet).
     """
     in_ptr = gloo_util.get_tensor_ptr(input_tensor)
     out_ptr = gloo_util.get_tensor_ptr(output_tensor)
     count = gloo_util.get_tensor_n_elements(input_tensor)
     gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
     reduce_op = gloo_util.get_gloo_reduce_op(allreduce_options.reduceOp)
     pygloo.allreduce(context, in_ptr, out_ptr, count, gloo_dtype, reduce_op)
コード例 #2
0
ファイル: gloo_collective_group.py プロジェクト: krfricke/ray
 def p2p_fn(tensor, context, peer):
     """Call ``pygloo.recv`` to fill ``tensor``'s buffer from rank ``peer``.

     The buffer pointer, element count and gloo dtype are all derived from
     ``tensor`` itself.
     """
     buf_ptr = gloo_util.get_tensor_ptr(tensor)
     n_elems = gloo_util.get_tensor_n_elements(tensor)
     gloo_dtype = gloo_util.get_gloo_tensor_dtype(tensor)
     pygloo.recv(context, buf_ptr, n_elems, gloo_dtype, peer)
コード例 #3
0
ファイル: gloo_collective_group.py プロジェクト: krfricke/ray
 def collective_fn(input_tensor, output_tensor, context):
     """Run ``pygloo.allgather`` from ``input_tensor`` into ``output_tensor``.

     The element count passed to gloo is that of ``input_tensor`` (this
     rank's contribution), and the gloo dtype is taken from the same tensor.
     """
     in_ptr = gloo_util.get_tensor_ptr(input_tensor)
     out_ptr = gloo_util.get_tensor_ptr(output_tensor)
     count = gloo_util.get_tensor_n_elements(input_tensor)
     gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
     pygloo.allgather(context, in_ptr, out_ptr, count, gloo_dtype)
コード例 #4
0
ファイル: gloo_collective_group.py プロジェクト: wxj0916/ray
 def collective_fn(input_tensor, output_tensor, context):
     """Run ``pygloo.reduce_scatter`` from ``input_tensor`` into ``output_tensor``.

     The ``size`` elements of ``input_tensor`` are split into ``world_size``
     equal receive counts, one entry per rank of ``self._gloo_context``.
     The reduce op comes from ``reducescatter_options.reduceOp``, a name
     captured from the enclosing scope (not visible in this snippet).

     Raises:
         RuntimeError: if ``size`` is not a multiple of ``world_size``.
             Truncating division would otherwise make the per-rank receive
             counts sum to less than ``size``. Those mismatched counts
             would be passed straight into the native gloo call.
     """
     size = gloo_util.get_tensor_n_elements(input_tensor)
     world_size = self._gloo_context.size
     chunk, remainder = divmod(size, world_size)
     if remainder:
         raise RuntimeError(
             "reducescatter requires the input element count to be "
             "divisible by the world size. Got {} elements for world "
             "size {}.".format(size, world_size))
     pygloo.reduce_scatter(
         context, gloo_util.get_tensor_ptr(input_tensor),
         gloo_util.get_tensor_ptr(output_tensor), size,
         # Equal receive count for every rank.
         [chunk] * world_size,
         gloo_util.get_gloo_tensor_dtype(output_tensor),
         gloo_util.get_gloo_reduce_op(reducescatter_options.reduceOp))
コード例 #5
0
ファイル: gloo_collective_group.py プロジェクト: krfricke/ray
 def collective_fn(input_tensor, output_tensor, context):
     """Run ``pygloo.broadcast`` from ``input_tensor`` into ``output_tensor``.

     The element count and gloo dtype are taken from ``input_tensor``. The
     source rank is ``root_rank``, a name captured from the enclosing scope
     (not visible in this snippet).
     """
     in_ptr = gloo_util.get_tensor_ptr(input_tensor)
     out_ptr = gloo_util.get_tensor_ptr(output_tensor)
     count = gloo_util.get_tensor_n_elements(input_tensor)
     gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
     pygloo.broadcast(context, in_ptr, out_ptr, count, gloo_dtype, root_rank)
コード例 #6
0
def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
    """Check the compatibility between tensor input and tensor list input."""
    if not tensors or not isinstance(tensors, list):
        raise RuntimeError("The first argument 'tensors' expects a list of tensors.")

    if len(tensors) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor in the first argument 'tensors'."
            " Got {} != 1.".format(len(tensors))
        )

    if not tensor_lists or not isinstance(tensor_lists, list):
        raise RuntimeError(
            "The second argument 'tensor_lists' expects a list of tensor list."
        )

    if len(tensor_lists) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor list "
            "in the second argument 'tensor_lists'."
            " Got {} != 1.".format(len(tensor_lists))
        )

    dtype = gloo_util.get_gloo_tensor_dtype(tensors[0])
    shape = gloo_util.get_tensor_shape(tensors[0])

    # check all tensors in `tensor_lists` match.
    for t in tensor_lists[0]:
        # check dtype
        dt = gloo_util.get_gloo_tensor_dtype(t)
        if dt != dtype:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same dtype. Got '{}' and '{}'.".format(dt, dtype)
            )
        s = gloo_util.get_tensor_shape(t)
        if s != shape:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same shape. Got '{}' and '{}'.".format(s, shape)
            )