def _flatten_for_scatter_gather(tensor_list, copy=False): """Flatten the tensor for gather/scatter operations. Args: tensor_list: the list of tensors to be scattered/gathered. copy: whether the copy the tensors in tensor_list into the buffer. Returns: The flattened tensor buffer. """ if not tensor_list: raise RuntimeError("Received an empty list.") t = tensor_list[0] # note we need a numpy dtype here. dtype = gloo_util.get_numpy_tensor_dtype(t) buffer_shape = [len(tensor_list)] + gloo_util.get_tensor_shape(t) buffer = numpy.empty(buffer_shape, dtype=dtype) if copy: for i, tensor in enumerate(tensor_list): gloo_util.copy_tensor(buffer[i], tensor) return buffer
def preprocess_fn():
    """Copy every tensor of ``tensor_lists`` into ``input_flattened``.

    ``tensor_lists[r][c]`` is written into ``input_flattened[r][c]`` via
    ``gloo_util.copy_tensor``. Takes no arguments: ``tensor_lists``,
    ``input_flattened`` and ``gloo_util`` are resolved as free names from the
    surrounding scope.
    """
    # Lazily enumerate (row, col, source) in the same row-major order as a
    # nested loop, so copies happen one at a time exactly as before.
    slots = (
        (row, col, src)
        for row, sources in enumerate(tensor_lists)
        for col, src in enumerate(sources)
    )
    for row, col, src in slots:
        gloo_util.copy_tensor(input_flattened[row][col], src)
def postprocess_fn():
    """Copy every element of ``output_flattened`` back into ``tensor_lists``.

    ``output_flattened[r][c]`` is written into the tensor
    ``tensor_lists[r][c]`` via ``gloo_util.copy_tensor``. Takes no arguments:
    ``tensor_lists``, ``output_flattened`` and ``gloo_util`` are resolved as
    free names from the surrounding scope.
    """
    # Lazily enumerate (row, col, destination) in row-major order, mirroring
    # the nested-loop traversal one copy at a time.
    slots = (
        (row, col, dst)
        for row, destinations in enumerate(tensor_lists)
        for col, dst in enumerate(destinations)
    )
    for row, col, dst in slots:
        gloo_util.copy_tensor(dst, output_flattened[row][col])