def _check_inputs_compatibility_for_scatter_gather(tensors, tensor_lists):
    """Check the compatibility between tensor input and tensor list input.

    Both arguments must be non-empty lists of length one. Every tensor in
    ``tensor_lists[0]`` must then report the same dtype and shape (as
    returned by ``gloo_util``) as ``tensors[0]``.

    Args:
        tensors: a list holding exactly one tensor.
        tensor_lists: a list holding exactly one list of tensors.

    Raises:
        RuntimeError: if either argument is not a non-empty list of length
            one, or if any tensor in ``tensor_lists[0]`` differs from
            ``tensors[0]`` in dtype or shape.
    """
    # Test the type BEFORE the truthiness: evaluating `not x` on a non-list
    # such as a multi-element array raises that object's own "ambiguous
    # truth value" error and would hide the intended message below.
    if not isinstance(tensors, list) or not tensors:
        raise RuntimeError("The first argument 'tensors' expects a list of tensors.")
    if len(tensors) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor in the first argument 'tensors'."
            " Got {} != 1.".format(len(tensors))
        )
    if not isinstance(tensor_lists, list) or not tensor_lists:
        raise RuntimeError(
            "The second argument 'tensor_lists' expects a list of tensor list."
        )
    if len(tensor_lists) != 1:
        raise RuntimeError(
            "Gloo only accept one tensor list "
            "in the second argument 'tensor_lists'."
            " Got {} != 1.".format(len(tensor_lists))
        )

    # The single tensor in `tensors` is the reference for dtype and shape.
    dtype = gloo_util.get_gloo_tensor_dtype(tensors[0])
    shape = gloo_util.get_tensor_shape(tensors[0])

    # Check that all tensors in `tensor_lists` match the reference.
    for t in tensor_lists[0]:
        dt = gloo_util.get_gloo_tensor_dtype(t)
        if dt != dtype:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same dtype. Got '{}' and '{}'.".format(dt, dtype)
            )
        s = gloo_util.get_tensor_shape(t)
        if s != shape:
            raise RuntimeError(
                "All tensor operands to scatter/gather must "
                "have the same shape. Got '{}' and '{}'.".format(s, shape)
            )
def _flatten_for_scatter_gather(tensor_list, copy=False):
    """Allocate a single buffer able to hold every tensor in ``tensor_list``.

    The buffer is a numpy array whose leading dimension is
    ``len(tensor_list)`` and whose remaining dimensions are the shape of the
    first tensor, with the numpy dtype of the first tensor.

    Args:
        tensor_list: the list of tensors to be scattered/gathered.
        copy: whether to copy the tensors in tensor_list into the buffer.
            When False the buffer is returned without being filled.

    Returns:
        The flattened tensor buffer.

    Raises:
        RuntimeError: if ``tensor_list`` is empty.
    """
    if not tensor_list:
        raise RuntimeError("Received an empty list.")

    reference = tensor_list[0]
    # numpy.empty needs a numpy dtype, not a gloo dtype.
    numpy_dtype = gloo_util.get_numpy_tensor_dtype(reference)
    stacked_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(reference)
    flat_buffer = numpy.empty(stacked_shape, dtype=numpy_dtype)

    if not copy:
        return flat_buffer

    # Slot i of the buffer receives the i-th tensor.
    for index, source in enumerate(tensor_list):
        gloo_util.copy_tensor(flat_buffer[index], source)
    return flat_buffer