def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
  """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors

  Raises:
    ValueError: if devices is not a list of n strings
  """
  n = len(xs)
  if len(devices) != n:
    raise ValueError("devices must be a list of length len(xs)")
  if n == 1:
    return xs
  shape = xs[0].shape.as_list()
  # tf.logging.info("allreduce_ring shape = %s" % shape)
  size = None if None in shape else mtf.list_product(shape)
  if size is None or size < 1024 or size % n != 0:
    return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

  def _circular_shift(l, n):
    n %= len(l)
    return l[-n:] + l[:-n]

  def _flatten_and_split(x):
    # tf.reshape treats [-1] as a special value denoting 1D flattening.
    return tf.split(tf.reshape(x, [-1]), n)

  def _concat_and_reshape(xs):
    return tf.reshape(tf.concat(xs, 0), shape)

  # [device, shard]
  x_split = mtf.parallel(devices, _flatten_and_split, xs)
  x_split_t = mtf.transpose_list_of_lists(x_split)

  y_split_t = []
  for shard in xrange(n):
    shard_xs = _circular_shift(x_split_t[shard], shard)
    shard_devices = _circular_shift(devices, shard)
    shard_ys = allreduce_ring_single_shard(
        shard_xs, shard_devices, reduction_fn_string)
    y_split_t.append(_circular_shift(shard_ys, -shard))
  y_split = mtf.transpose_list_of_lists(y_split_t)
  ys = mtf.parallel(devices, _concat_and_reshape, y_split)
  return ys
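# Usage sketch (illustrative, not part of the library): given four tensors of
# identical shape already placed on four devices, allreduce_ring returns the
# elementwise sum on every device.  The device names and shape below are
# assumptions for the example; size 1024 * 4 is large and divisible by n, so
# the sharded ring path is taken.
#
#   devices = ["/gpu:0", "/gpu:1", "/gpu:2", "/gpu:3"]
#   xs = []
#   for d in devices:
#     with tf.device(d):
#       xs.append(tf.ones([1024, 4]))
#   summed = allreduce_ring(xs, devices, reduction_fn_string="SUM")
#   # summed[i] lives on devices[i] and equals the elementwise sum of all xs.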
def assign_to_slices(self, assign_fn, values, assign_to_tensor_list=None):
  """Assign to the slice variables.

  Args:
    assign_fn: a function from
        (mtf.Variable, tf.Variable, tf.Tensor) -> tf.Operation
    values: a list of tf.Tensor
    assign_to_tensor_list: an optional list of tf.Variable

  Returns:
    a tf.operation
  """
  if assign_to_tensor_list is None:
    assign_to_tensor_list = self._laid_out_tensor.all_slices
  # Handle both N -> 1 and N -> N cases.
  num_slices = min(len(assign_to_tensor_list), len(values))
  devices = [""] * num_slices
  return tf.group(
      mtf.parallel(devices, assign_fn, [self._variable] * len(devices),
                   assign_to_tensor_list[:num_slices],
                   values[:num_slices]))
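# Hypothetical example of an `assign_fn` compatible with this signature
# (an assumption, not library code): it ignores the mtf.Variable metadata and
# writes the value into the slice variable using TF1-style tf.assign.
#
#   def simple_assign_fn(mtf_var, slice_var, value):
#     return tf.assign(slice_var, value)
#
#   # some_laid_out_var.assign_to_slices(simple_assign_fn, new_slice_values)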
def allconcat_ring(xs, devices, concat_axis):
  """Concatenate all Tensors everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # [target, source]
  parts = [[xs[target] if target == source else None for source in xrange(n)]
           for target in xrange(n)]
  for distance in xrange(1, n // 2 + 1):
    for target in xrange(n):
      source = (target + distance) % n
      if parts[target][source] is None:
        with tf.device(devices[target]):
          parts[target][source] = tf.identity(parts[(target + 1) % n][source])
      source = (target - distance) % n
      if parts[target][source] is None:
        with tf.device(devices[target]):
          parts[target][source] = tf.identity(parts[(target - 1) % n][source])
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
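# Illustrative note (not library code): with n == 4 and concat_axis == 0,
# every device ends up holding tf.concat([xs[0], xs[1], xs[2], xs[3]], 0).
# Each device copies missing parts only from its two ring neighbors, so a
# part is relayed at most n // 2 hops rather than being broadcast directly
# from its source device.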
def slicewise(self, fn, *inputs):
  """Execute a function in parallel on all slices.

  Args:
    fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
    *inputs: a list of inputs.  Each input is either a LaidOutTensor or
      is convertible to a tf.Tensor.

  Returns:
    a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
  """
  if fn == tf.add:
    assert len(inputs) == 2
    if isinstance(inputs[0], mtf.LazyAllreduceSum):
      # sum of LazyAllreduceSum (keep delaying the allreduce)
      return inputs[0] + inputs[1]
  # convert all inputs to LaidOutTensor where possible
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
            else [x] * len(self.devices) for x in inputs]
  ret = mtf.parallel(self.devices, fn, *inputs)
  if isinstance(ret[0], tuple):
    ret = mtf.transpose_list_of_lists(ret)
    return tuple([self.LaidOutTensor(t) for t in ret])
  else:
    return self.LaidOutTensor(ret)
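# Usage sketch (assumes `mesh_impl` is a placement mesh implementation and
# `laid_out_a`, `laid_out_b` are LaidOutTensors with matching slice shapes;
# these names are illustrative):
#
#   laid_out_sum = mesh_impl.slicewise(tf.add, laid_out_a, laid_out_b)
#   # Each slice of laid_out_sum is computed on its own device as
#   # tf.add(a_slice, b_slice).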
def to_tf_tensor(self, *inputs):
  # Keep the second (un-converted) input around as a fallback return value.
  inputs_save = inputs[1]
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
            else [x] * len(self.devices) for x in inputs]
  # Special-case: leave this particular logits tensor untouched.
  if inputs[1][0].name == "logits/add_1/parallel_0_1/Add:0":
    return inputs
  try:
    if inputs[0][0].shape[-1] == inputs[1][0].shape[-1]:
      return inputs_save
    # Concatenate the slices of the first input along its last axis on GPU:0.
    tensor_axis = len(inputs[0][0].shape) - 1
    slices = mtf.parallel(
        ["GPU:0"], tf.concat, [inputs[0]],
        axis=[tensor_axis] * len(["GPU:0"]))
    # While a slice has a leading dimension of size 3 and rank below 5,
    # average it out over that leading axis.
    for i in range(len(slices)):
      while slices[i].shape[0] == 3 and len(slices[i].shape) < 5:
        slices[i] = tf.reduce_mean(slices[i], axis=[0])
    return slices
  except Exception:  # pylint: disable=broad-except
    # On any failure, fall back to the saved second input.
    return inputs_save
def alltoall_ring(xs, devices, split_axis, concat_axis):
  """MPI alltoall operation.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # set up
  # [target, source]
  parts = [[None] * n for i in xrange(n)]

  def my_split(x, size_splits):
    total_size = tf.shape(x)[split_axis]
    part_size = total_size // sum(size_splits)
    return tf.split(x, [s * part_size for s in size_splits], axis=split_axis)

  forward_message_size = (n - 1) // 2
  backward_message_size = (n - 1) - forward_message_size
  forward_messages = [None] * n
  backward_messages = [None] * n
  for i in xrange(n):
    with tf.device(devices[i]):
      if i >= backward_message_size:
        a, b, c, d = my_split(
            xs[i], [i - backward_message_size, backward_message_size,
                    1, n - i - 1])
        backward_messages[i] = b
        parts[i][i] = c
        forward_messages[i] = tf.concat([d, a], axis=split_axis)
      else:
        a, b, c, d = my_split(
            xs[i], [i, 1, forward_message_size, backward_message_size - i])
        backward_messages[i] = tf.concat([d, a], axis=split_axis)
        parts[i][i] = b
        forward_messages[i] = c
  for step in xrange(1, max(forward_message_size, backward_message_size) + 1):
    new_forward_messages = [None] * n
    new_backward_messages = [None] * n
    for i in xrange(n):
      with tf.device(devices[i]):
        if forward_message_size > 0:
          parts[i][(i - step) % n], new_forward_messages[i] = my_split(
              forward_messages[(i - 1) % n], [1, forward_message_size - 1])
        if backward_message_size > 0:
          new_backward_messages[i], parts[i][(i + step) % n] = my_split(
              backward_messages[(i + 1) % n], [backward_message_size - 1, 1])
    forward_message_size -= 1
    backward_message_size -= 1
    forward_messages = new_forward_messages
    backward_messages = new_backward_messages
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
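# Worked trace (illustrative): with n == 5, forward_message_size == 2 and
# backward_message_size == 2.  After the setup pass each device keeps its own
# part and packs two parts into a message travelling in each direction around
# the ring.  On each of the two ring steps, device i peels one part off the
# forward message coming from device (i - 1) % n and one off the backward
# message coming from device (i + 1) % n, so after max(2, 2) steps every
# parts[target][source] entry is filled.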
def to_tf_tensor2(self, *inputs):
  # Unused here; kept for parity with to_tf_tensor.
  inputs_save = inputs[1]
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
            else [x] * len(self.devices) for x in inputs]
  # Concatenate the slices of the first input along its last axis on GPU:0.
  tensor_axis = len(inputs[0][0].shape) - 1
  ret = mtf.parallel(
      ["GPU:0"], tf.concat, [inputs[0]],
      axis=[tensor_axis] * len(["GPU:0"]))
  return ret
def alltoall_pointtwise(xs, devices, split_axis, concat_axis):
  """MPI alltoall operation.

  Implementation of alltoall using pointwise communication.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # [target, source]
  parts = mtf.transpose_list_of_lists(
      mtf.parallel(devices, tf.split, xs, [n] * n, axis=[split_axis] * n))
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
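# Semantics sketch (assumes the split_axis dimension is divisible by n): each
# xs[source] is split into n equal pieces along split_axis, and device
# `target` receives piece number `target` from every source, concatenated
# along concat_axis -- the standard alltoall layout.  Unlike alltoall_ring,
# each piece is pulled directly from its source device rather than relayed
# around the ring.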
def assign_to_slices(self, assign_fn, values):
  """Assign to the slice variables.

  Args:
    assign_fn: a function from
        (mtf.Variable, tf.Variable, tf.Tensor) -> tf.Operation
    values: a list of tf.Tensor

  Returns:
    a tf.operation
  """
  return tf.group(mtf.parallel(
      self._mesh_impl.devices, assign_fn, [self._variable] * len(values),
      self.laid_out_tensor.all_slices, values))