def collective_fn(input_tensor, output_tensor, context):
    """Invoke ``pygloo.allreduce`` on the given tensor pair.

    The element count and Gloo dtype are both derived from ``input_tensor``.
    The reduce op comes from ``allreduce_options.reduceOp``, where
    ``allreduce_options`` is a free variable resolved outside this function.

    Args:
        input_tensor: tensor whose pointer is passed as the first buffer.
        output_tensor: tensor whose pointer is passed as the second buffer.
        context: Gloo context forwarded unchanged to ``pygloo.allreduce``.
    """
    src_ptr = gloo_util.get_tensor_ptr(input_tensor)
    dst_ptr = gloo_util.get_tensor_ptr(output_tensor)
    n_elements = gloo_util.get_tensor_n_elements(input_tensor)
    gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
    gloo_op = gloo_util.get_gloo_reduce_op(allreduce_options.reduceOp)
    pygloo.allreduce(context, src_ptr, dst_ptr, n_elements, gloo_dtype, gloo_op)
def p2p_fn(tensor, context, peer):
    """Invoke ``pygloo.recv`` for ``tensor``, passing ``peer`` through.

    Args:
        tensor: tensor whose pointer, element count and Gloo dtype are
            handed to ``pygloo.recv``.
        context: Gloo context forwarded unchanged to ``pygloo.recv``.
        peer: forwarded as the final argument of ``pygloo.recv``
            (presumably the remote rank; confirm against the pygloo API).
    """
    buffer_ptr = gloo_util.get_tensor_ptr(tensor)
    n_elements = gloo_util.get_tensor_n_elements(tensor)
    gloo_dtype = gloo_util.get_gloo_tensor_dtype(tensor)
    pygloo.recv(context, buffer_ptr, n_elements, gloo_dtype, peer)
def collective_fn(input_tensor, output_tensor, context):
    """Invoke ``pygloo.allgather`` on the given tensor pair.

    The element count and Gloo dtype are both derived from ``input_tensor``.

    Args:
        input_tensor: tensor whose pointer is passed as the first buffer.
        output_tensor: tensor whose pointer is passed as the second buffer.
        context: Gloo context forwarded unchanged to ``pygloo.allgather``.
    """
    src_ptr = gloo_util.get_tensor_ptr(input_tensor)
    dst_ptr = gloo_util.get_tensor_ptr(output_tensor)
    n_elements = gloo_util.get_tensor_n_elements(input_tensor)
    gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
    pygloo.allgather(context, src_ptr, dst_ptr, n_elements, gloo_dtype)
def collective_fn(input_tensor, output_tensor, context):
    """Invoke ``pygloo.reduce_scatter`` on the given tensor pair.

    The total element count is read from ``input_tensor`` and the Gloo dtype
    from ``output_tensor``. The count list passed to pygloo has one entry per
    rank, each equal to ``total // world_size``, where ``world_size`` is
    ``self._gloo_context.size``. ``self`` and ``reducescatter_options`` are
    free variables resolved outside this function.

    Args:
        input_tensor: tensor whose pointer is passed as the first buffer.
        output_tensor: tensor whose pointer is passed as the second buffer.
        context: Gloo context forwarded unchanged to
            ``pygloo.reduce_scatter``.
    """
    total_elements = gloo_util.get_tensor_n_elements(input_tensor)
    n_ranks = self._gloo_context.size
    src_ptr = gloo_util.get_tensor_ptr(input_tensor)
    dst_ptr = gloo_util.get_tensor_ptr(output_tensor)
    # Even split: every rank is assigned the same floor-divided element count.
    per_rank_counts = [total_elements // n_ranks for _ in range(n_ranks)]
    gloo_dtype = gloo_util.get_gloo_tensor_dtype(output_tensor)
    gloo_op = gloo_util.get_gloo_reduce_op(reducescatter_options.reduceOp)
    pygloo.reduce_scatter(
        context,
        src_ptr,
        dst_ptr,
        total_elements,
        per_rank_counts,
        gloo_dtype,
        gloo_op,
    )
def collective_fn(input_tensor, output_tensor, context):
    """Invoke ``pygloo.broadcast`` on the given tensor pair.

    The element count and Gloo dtype are both derived from ``input_tensor``.
    ``root_rank`` is a free variable resolved outside this function and is
    forwarded as the final argument.

    Args:
        input_tensor: tensor whose pointer is passed as the first buffer.
        output_tensor: tensor whose pointer is passed as the second buffer.
        context: Gloo context forwarded unchanged to ``pygloo.broadcast``.
    """
    src_ptr = gloo_util.get_tensor_ptr(input_tensor)
    dst_ptr = gloo_util.get_tensor_ptr(output_tensor)
    n_elements = gloo_util.get_tensor_n_elements(input_tensor)
    gloo_dtype = gloo_util.get_gloo_tensor_dtype(input_tensor)
    pygloo.broadcast(context, src_ptr, dst_ptr, n_elements, gloo_dtype, root_rank)
def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
    """Check the compatibility between tensor input and tensor list input.

    Validates that ``tensors`` is a list holding exactly one tensor, that
    ``tensor_lists`` is a list holding exactly one tensor list, and that every
    tensor in ``tensor_lists[0]`` has the same Gloo dtype and shape as
    ``tensors[0]``.

    Args:
        tensors: a non-empty list containing exactly one tensor.
        tensor_lists: a non-empty list containing exactly one list of tensors.

    Raises:
        RuntimeError: if either argument is not a list, is empty, has more
            than one element, or if any tensor in ``tensor_lists[0]`` differs
            from ``tensors[0]`` in dtype or shape.
    """
    # The type check must run before the emptiness check: truth-testing a
    # bare array/tensor passed by mistake can raise ValueError (ambiguous
    # truth value) and hide the intended, more helpful RuntimeError.
    if not isinstance(tensors, list) or not tensors:
        raise RuntimeError("The first argument 'tensors' expects a list of tensors.")
    if len(tensors) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor in the first argument 'tensors'."
            " Got {} != 1.".format(len(tensors))
        )

    if not isinstance(tensor_lists, list) or not tensor_lists:
        raise RuntimeError(
            "The second argument 'tensor_lists' expects a list of tensor list."
        )
    if len(tensor_lists) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor list "
            "in the second argument 'tensor_lists'."
            " Got {} != 1.".format(len(tensor_lists))
        )

    # The single tensor in `tensors` is the reference for dtype and shape.
    dtype = gloo_util.get_gloo_tensor_dtype(tensors[0])
    shape = gloo_util.get_tensor_shape(tensors[0])

    # Check that all tensors in `tensor_lists` match the reference.
    for t in tensor_lists[0]:
        dt = gloo_util.get_gloo_tensor_dtype(t)
        if dt != dtype:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same dtype. Got '{}' and '{}'.".format(dt, dtype)
            )
        s = gloo_util.get_tensor_shape(t)
        if s != shape:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same shape. Got '{}' and '{}'.".format(s, shape)
            )