# Module-level imports assumed by this excerpt.
from six.moves import xrange  # pylint: disable=redefined-builtin

import mesh_tensorflow as mtf
import tensorflow as tf


def dense(x, output_dim, reduced_dims=None, expert_dims=None,
          use_bias=True, activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          name=None):
  """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimensions which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype
    slice_dtype: a tf.dtype
    name: a string used as the variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
  if expert_dims is None:
    expert_dims = []
  if reduced_dims is None:
    reduced_dims = x.shape.dims[-1:]
  w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
  output_shape = mtf.Shape(
      [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])
  with tf.variable_scope(name, default_name="dense"):
    # Scale the initializer so that the output variance is independent of the
    # total size of the reduced dimensions.
    stddev = mtf.list_product(d.size for d in reduced_dims) ** -0.5
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        w_shape,
        initializer=tf.random_normal_initializer(stddev=stddev),
        master_dtype=master_dtype,
        slice_dtype=slice_dtype,
        activation_dtype=x.dtype)
    y = mtf.einsum([x, w], output_shape)
    if use_bias:
      b = mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape(expert_dims + [output_dim]),
          initializer=tf.zeros_initializer(),
          activation_dtype=x.dtype)
      y += b
    if activation is not None:
      y = activation(y)
    return y
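

# Illustrative usage sketch (not part of the original module): how `dense`
# might be called to build a small feed-forward block. The mesh construction,
# dimension names, and sizes below are assumptions made up for the example.
def _example_dense_usage():
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "example_mesh")
  batch_dim = mtf.Dimension("batch", 32)
  model_dim = mtf.Dimension("d_model", 512)
  hidden_dim = mtf.Dimension("d_ff", 2048)
  # A placeholder activation tensor of shape [batch, d_model].
  x = mtf.zeros(mesh, mtf.Shape([batch_dim, model_dim]))
  # Project [batch, d_model] -> [batch, d_ff], reducing over the last
  # dimension of x (the default reduced_dims), then project back.
  h = dense(x, hidden_dim, activation=mtf.relu, name="wi")
  return dense(h, model_dim, name="wo")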


def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
  """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors

  Raises:
    ValueError: if devices is not a list of n strings
  """
  n = len(xs)
  if len(devices) != n:
    raise ValueError("devices must be a list of length len(xs)")
  if n == 1:
    return xs
  shape = xs[0].shape.as_list()
  # tf.logging.info("allreduce_ring shape = %s" % shape)
  size = None if None in shape else mtf.list_product(shape)
  # Fall back to the single-shard implementation for small or oddly-shaped
  # tensors that cannot be evenly split into n shards.
  if size is None or size < 1024 or size % n != 0:
    return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

  def _circular_shift(l, n):
    n %= len(l)
    return l[-n:] + l[:-n]

  def _flatten_and_split(x):
    return tf.split(tf.reshape(x, [size]), n)

  def _concat_and_reshape(xs):
    return tf.reshape(tf.concat(xs, 0), shape)

  # Split each tensor into n flat shards, one per device: [device, shard].
  x_split = mtf.parallel(devices, _flatten_and_split, xs)
  x_split_t = mtf.transpose_list_of_lists(x_split)
  y_split_t = []
  # Run one ring allreduce per shard, rotating the ring so that each shard's
  # traffic starts on a different device and the link load is balanced.
  for shard in xrange(n):
    shard_xs = _circular_shift(x_split_t[shard], shard)
    shard_devices = _circular_shift(devices, shard)
    shard_ys = allreduce_ring_single_shard(
        shard_xs, shard_devices, reduction_fn_string)
    y_split_t.append(_circular_shift(shard_ys, -shard))
  # Reassemble the reduced shards into full tensors on every device.
  y_split = mtf.transpose_list_of_lists(y_split_t)
  ys = mtf.parallel(devices, _concat_and_reshape, y_split)
  return ys
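

# Illustrative usage sketch (not part of the original module): feeding one
# tensor copy per device through allreduce_ring. The device strings and the
# [4, 256] shape are assumptions for the example; real callers pass the mesh's
# actual device list, and the shape here is chosen so the sharded ring path is
# taken rather than the single-shard fallback.
def _example_allreduce_usage():
  devices = ["/cpu:0"] * 4  # placeholder device strings
  xs = [tf.ones([4, 256]) for _ in devices]  # one local copy per device
  # Each returned tensor equals the elementwise sum over all four inputs.
  return allreduce_ring(xs, devices, reduction_fn_string="SUM")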