def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
  """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors

  Raises:
    ValueError: if devices is not a list of n strings
  """
  n = len(xs)
  if len(devices) != n:
    raise ValueError("devices must be a list of length len(xs)")
  if n == 1:
    return xs
  shape = xs[0].shape.as_list()
  # tf.logging.info("allreduce_ring shape = %s" % shape)
  size = None if None in shape else mtf.list_product(shape)
  if size is None or size < 1024 or size % n != 0:
    return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

  def _circular_shift(l, n):
    n %= len(l)
    return l[-n:] + l[:-n]

  def _flatten_and_split(x):
    # tf.reshape treats [-1] as a special value denoting 1D flattening.
    return tf.split(tf.reshape(x, [-1]), n)

  def _concat_and_reshape(xs):
    return tf.reshape(tf.concat(xs, 0), shape)

  # [device, shard]
  x_split = mtf.parallel(devices, _flatten_and_split, xs)
  x_split_t = mtf.transpose_list_of_lists(x_split)
  y_split_t = []
  for shard in xrange(n):
    shard_xs = _circular_shift(x_split_t[shard], shard)
    shard_devices = _circular_shift(devices, shard)
    shard_ys = allreduce_ring_single_shard(
        shard_xs, shard_devices, reduction_fn_string)
    y_split_t.append(_circular_shift(shard_ys, -shard))
  y_split = mtf.transpose_list_of_lists(y_split_t)
  ys = mtf.parallel(devices, _concat_and_reshape, y_split)
  return ys
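
# Illustrative usage sketch (not part of the original library code): a minimal
# example assuming two available CPU device strings and same-shaped slices.
# With reduction_fn_string="SUM", every returned slice should equal the
# elementwise sum of all input slices, placed on the corresponding device.
# The slice shape [4, 512] (2048 elements) is chosen so the sharded ring path
# above is taken rather than the single-shard fallback.
def _example_allreduce_ring_usage():
  devices = ["/cpu:0", "/cpu:1"]  # hypothetical device names
  xs = []
  for d in devices:
    with tf.device(d):
      xs.append(tf.ones([4, 512]))
  # Each returned tensor lives on devices[i] and equals xs[0] + xs[1].
  return allreduce_ring(xs, devices, reduction_fn_string="SUM")
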
def slicewise(self, fn, *inputs):
  """Execute a function in parallel on all slices.

  Args:
    fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
    *inputs: a list of inputs.  Each input is either a LaidOutTensor or
      is convertible to a tf.Tensor.

  Returns:
    a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
  """
  if fn == tf.add:
    assert len(inputs) == 2
    if isinstance(inputs[0], mtf.LazyAllreduceSum):
      # sum of LazyAllreduceSum (keep delaying the allreduce)
      return inputs[0] + inputs[1]
  # convert all inputs to LaidOutTensor where possible
  inputs = mtf.convert_args_to_laid_out_tensors(inputs)
  inputs = [
      x.tensor_list if isinstance(x, self.LaidOutTensor)
      else [x] * len(self.devices) for x in inputs]
  ret = mtf.parallel(self.devices, fn, *inputs)
  if isinstance(ret[0], tuple):
    ret = mtf.transpose_list_of_lists(ret)
    return tuple([self.LaidOutTensor(t) for t in ret])
  else:
    return self.LaidOutTensor(ret)
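
# Illustrative usage sketch (not part of the original code): `mesh_impl` is
# assumed to be an instance of the mesh implementation class that defines
# slicewise above, and `laid_out_x` a LaidOutTensor it produced; both names
# are hypothetical.  slicewise applies fn to each device's slice in parallel.
def _example_slicewise_usage(mesh_impl, laid_out_x):
  # Unary op: tf.exp is applied independently to each device's slice.
  y = mesh_impl.slicewise(tf.exp, laid_out_x)
  # Binary op with a plain Python scalar: the scalar is not a LaidOutTensor,
  # so it is replicated to every device before tf.add runs per slice.
  z = mesh_impl.slicewise(tf.add, laid_out_x, 1.0)
  return y, z
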
def alltoall_pointtwise(xs, devices, split_axis, concat_axis):
  """MPI alltoall operation.

  Implementation of alltoall using pointwise communication.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # [target, source]
  parts = mtf.transpose_list_of_lists(
      mtf.parallel(devices, tf.split, xs, [n] * n, axis=[split_axis] * n))
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
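
# Illustrative usage sketch (not part of the original code): with n=2 devices,
# each input is split into n pieces along split_axis, and piece j from every
# source device is concatenated along concat_axis on device j.  Device names
# and shapes below are hypothetical.
def _example_alltoall_pointtwise_usage():
  devices = ["/cpu:0", "/cpu:1"]  # hypothetical device names
  xs = []
  for d in devices:
    with tf.device(d):
      xs.append(tf.zeros([2, 6]))
  # Each [2, 6] input is split along axis 1 into two [2, 3] pieces; each
  # device then concatenates the pieces it receives along axis 0, so every
  # output has shape [4, 3].
  return alltoall_pointtwise(xs, devices, split_axis=1, concat_axis=0)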