def testWeightsNonzero(self):
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", inputs.shape.as_list()[0])
  channels_dim = mtf.Dimension("channels", inputs.shape.as_list()[1])

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.weights_nonzero(mtf_inputs)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = common_layers.weights_nonzero(inputs)
  tf_group = lowering.copy_masters_to_slices()
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertAllEqual(actual, expected)
def testLayerNorm(self):
  batch = 2
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)

  mtf_inputs = mtf.import_tf_tensor(
      mesh, inputs, shape=mtf.Shape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.layer_norm(mtf_inputs, dim=channels_dim)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  expected_outputs = common_layers.layer_norm(inputs)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  self.evaluate(init)
  self.evaluate(tf_group)
  actual, expected = self.evaluate([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
  """Learned 2d positional embedding for images."""
  mesh = targets.mesh
  hparams = self._hparams
  activation_dtype = self.set_activation_type()

  rows_dim = mtf.Dimension("rows", hparams.img_len)
  cols_dim = mtf.Dimension("cols", hparams.img_len * hparams.num_channels)

  positional_emb_rows_var = mtf.get_variable(
      mesh, "positional_emb_rows",
      mtf.Shape([max_length_dim, model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=activation_dtype)
  positional_emb_cols_var = mtf.get_variable(
      mesh, "positional_emb_cols",
      mtf.Shape([max_length_dim, model_dim]),
      initializer=tf.random_normal_initializer(),
      activation_dtype=activation_dtype)

  targets_position_x = mtf.range(mesh, rows_dim, dtype=tf.int32)
  targets_position_y = mtf.range(mesh, cols_dim, dtype=tf.int32)
  position_x = mtf.broadcast(
      mtf.gather(positional_emb_rows_var, targets_position_x,
                 max_length_dim),
      mtf.Shape([rows_dim, cols_dim, model_dim]))
  position_y = mtf.broadcast(
      mtf.gather(positional_emb_cols_var, targets_position_y,
                 max_length_dim),
      mtf.Shape([rows_dim, cols_dim, model_dim]))
  return position_x + position_y
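# Worked example (hypothetical hparams values, not from the model code): with
# hparams.img_len = 32 and hparams.num_channels = 3, rows_dim has size 32 and
# cols_dim has size 96, so every (row, col) position receives the sum of a
# learned row embedding and a learned column embedding of size model_dim.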
def testDense(self, units, use_bias):
  batch = 2
  channels = 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  depth_dim = mtf.Dimension("depth", units)

  mtf_inputs = mtf.infeed(
      mesh, inputs, shape=mtf.TensorShape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.dense(mtf_inputs,
                                 output_dim=depth_dim,
                                 reduced_dims=[channels_dim],
                                 activation=mtf.relu,
                                 use_bias=use_bias)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[1], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.outfeed(mtf_outputs)

  expected_outputs = tf.keras.layers.Dense(units=units,
                                           activation=tf.nn.relu,
                                           use_bias=use_bias)(inputs)
  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual, expected = sess.run([actual_outputs, expected_outputs])

  self.assertEqual(actual.shape, expected.shape)
def testDenseReluDense(self):
  batch = 2
  channels = 3
  hidden = 5
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  hidden_dim = mtf.Dimension("hidden", hidden)

  mtf_inputs = mtf.infeed(
      mesh, inputs, shape=mtf.TensorShape([batch_dim, channels_dim]))
  mtf_outputs = mtf_layers.dense_relu_dense(mtf_inputs,
                                            hidden_channels=hidden_dim)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[1], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.outfeed(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (mtf.Dimension(name="x", size=5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    dimension = mtf.convert_to_dimension(None)
    self.assertEqual(dimension, None)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension(name="x", size=4),
                  mtf.Dimension(name="y", size=8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(shape, mtf.Shape([mtf.Dimension(name="x", size=4),
                                       mtf.Dimension(name="y", size=8)]))

  def testConvertToShapeGenericInputs(self):
    shape = mtf.convert_to_shape(None)
    self.assertEqual(shape, None)
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    layout_rules = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")
def toy_model(features, mesh):
  """A toy model implemented by mesh tensorflow."""
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  x = mtf.import_tf_tensor(mesh, features, mtf.Shape([batch_dim, io_dim]))
  h = mtf_layers.dense(x, hidden_dim, name='layer1', use_bias=False)
  y = mtf_layers.dense(h, io_dim, name='layer2', use_bias=False)
  loss = mtf.reduce_sum(mtf.square(y - x))
  return y, loss
def batch_dims(self):
  hparams = self._hparams
  if hparams.outer_batch_size == 0:
    return [mtf.Dimension("batch", hparams.batch_size)]
  else:
    if hparams.batch_size % hparams.outer_batch_size != 0:
      raise ValueError(
          "hparams.outer_batch_size must divide hparams.batch_size")
    return [
        mtf.Dimension("outer_batch", hparams.outer_batch_size),
        mtf.Dimension("inner_batch",
                      hparams.batch_size // hparams.outer_batch_size)
    ]
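# Worked example (hypothetical hparams values): with batch_size=64 and
# outer_batch_size=4, batch_dims() returns
# [Dimension("outer_batch", 4), Dimension("inner_batch", 16)], since
# 64 // 4 == 16; with outer_batch_size=0 it returns [Dimension("batch", 64)].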
def dense_relu_dense(x, hidden_channels, dropout=0.0,
                     dropout_broadcast_dims=None, name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    wi, wo = mtf.unstack(w, io)
    h = mtf.relu(mtf.einsum([x, wi]))
    if dropout != 0.0:
      h = mtf.dropout(h, 1.0 - dropout,
                      noise_shape=h.shape - dropout_broadcast_dims)
    return mtf.einsum([h, wo])
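# Note on the initializer above (a reading of the code, with illustrative
# numbers): both projections share one stddev of
# (hidden_channels.size * io_channels.size) ** -0.25, e.g. for io_channels=512
# and hidden_channels=2048, stddev = (2048 * 512) ** -0.25 = 0.03125, so the
# product of the two layers' scales is roughly (hidden * io) ** -0.5, in the
# spirit of a fan-in style initialization.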
def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype):
  """Create Parameters for Multihead Attention.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  o_stddev = (io_channels.size * heads.size) ** -0.5

  def qkvo_initializer(shape, dtype=None, partition_info=None,
                       verify_shape=None):
    del partition_info, verify_shape
    return tf.random_normal(shape, dtype=dtype) * tf.reshape(
        [qk_stddev, qk_stddev, v_stddev, o_stddev], [4, 1, 1, 1])

  var = mtf.get_variable(
      mesh, "qkvo", mtf.Shape([qkvo, heads, io_channels, kv_channels]),
      initializer=qkvo_initializer, activation_dtype=activation_dtype)
  q_var, k_var, v_var, o_var = mtf.unstack(var, qkvo)
  return q_var, k_var, v_var, o_var
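# Worked example of the per-slice stddevs above (illustrative sizes only):
# with io_channels=512, kv_channels=64 and heads=8,
#   qk_stddev = 512 ** -0.5 * 64 ** -0.25 = 0.015625
#   v_stddev  = 512 ** -0.5               ~ 0.0442
#   o_stddev  = (512 * 8) ** -0.5         = 0.015625
# The single "qkvo" variable is initialized with these four scales and then
# unstacked into q_var, k_var, v_var and o_var.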
def attention_bias_local_block(mesh, block_length, memory_length,
                               dtype=tf.int32):
  """Bias for attention for local blocks where attention to right is disallowed.

  Create the bias matrix by using two separate masks, one for the memory part
  which doesn't overlap with the query and second which interacts with the
  query and should be disallowed to look to the right of the current query
  position.

  Args:
    mesh: a MeshTensorflow object
    block_length: a mtf.Dimension
    memory_length: a mtf.Dimension
    dtype: a tf.dtype

  Returns:
    a mtf.Tensor with shape [block_length, memory_length]
  """
  memory_length = mtf.Dimension(memory_length.name, block_length.size)
  memory_mask = mtf.zeros(mesh, [block_length, memory_length], dtype=dtype)
  mask = mtf.cast(mtf.less(mtf.range(mesh, block_length, dtype=dtype),
                           mtf.range(mesh, memory_length, dtype=dtype)),
                  dtype=dtype)
  mask = mtf.cast(
      mtf.concat([memory_mask, mask], memory_length.name),
      dtype=tf.float32) * -1e9
  return mask
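# Worked example (illustrative size): for block_length=3 the returned bias,
# after concatenation along the memory dimension, is the 3 x 6 matrix
#   [[0, 0, 0,   0, -1e9, -1e9],
#    [0, 0, 0,   0,    0, -1e9],
#    [0, 0, 0,   0,    0,    0]]
# i.e. the previous-block half is fully visible and the current-block half is
# causal (no attention to positions right of the query).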
def _concat_equal_sizes(xs, dim, new_dim_name):
  axis = xs[0].shape.dims.index(dim)
  ret = mtf.stack(xs, "tmp_concat", axis)
  new_shape = mtf.TensorShape(
      xs[0].shape.dims[:axis] +
      [mtf.Dimension(new_dim_name, dim.size * len(xs))] +
      xs[0].shape.dims[axis + 1:])
  return mtf.reshape(ret, new_shape)
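# Minimal usage sketch for _concat_equal_sizes; dimension names and sizes here
# are illustrative only, not taken from the surrounding code.
def _concat_equal_sizes_example(mesh):
  batch = mtf.Dimension("batch", 2)
  heads = mtf.Dimension("heads", 4)
  xs = [mtf.zeros(mesh, [batch, heads], dtype=tf.float32) for _ in range(3)]
  # Stacks the three tensors to [batch, tmp_concat=3, heads=4], then reshapes
  # to [batch, all_heads=12].
  return _concat_equal_sizes(xs, heads, "all_heads")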
def mtf_model_fn(self, features, mesh):
  hparams = self._hparams
  # tf_x = tf.random_uniform([hparams.batch_size, hparams.io_size])
  tf_x = tf.matmul(
      tf.reshape(tf.lin_space(0., 1.0, hparams.batch_size),
                 [hparams.batch_size, 1]),
      tf.reshape(tf.lin_space(0., 1.0, hparams.io_size),
                 [1, hparams.io_size]))
  batch_dim = mtf.Dimension("batch", hparams.batch_size)
  hidden_dim = mtf.Dimension("hidden", hparams.hidden_size)
  io_dim = mtf.Dimension("io", hparams.io_size)

  x = mtf.infeed_fully_replicated(
      mesh, tf_x, mtf.TensorShape([batch_dim, io_dim]))
  h = mtf_layers.dense(x, hidden_dim, name="layer1", use_bias=False)
  y = mtf_layers.dense(h, io_dim, name="layer2", use_bias=False)
  loss = mtf.reduce_sum(mtf.square(y - x))
  return None, loss
def test_variable_placer(self):
  sizes = [100, 0, 0, 0]
  device_list = ['cpu:0', 'cpu:1', 'cpu:2', 'cpu:3']

  with tf.Graph().as_default() as g:
    var_placer = mtf_utils.BalancedVariablePlacer(device_list, sizes)
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'my_mesh', var_placer)

    hidden_dim = mtf.Dimension('hidden', 10)
    output_dim = mtf.Dimension('output_feature', 10)

    for i in xrange(5):
      # Each variable takes 400 Bytes, and will be placed from cpu:1.
      mtf.get_variable(mesh, 'w{}'.format(i), [hidden_dim, output_dim])

    for i in xrange(5):
      var = g.get_tensor_by_name('w{}:0'.format(i))
      device = (i + 1) % len(device_list)
      self.assertEqual('cpu:{}'.format(device), var.device)
def mnist_model(image, labels, mesh):
  """The model.

  Args:
    image: a tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

  x = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                           mtf.Shape([batch_dim, rows_dim, cols_dim]))
  h1 = mtf_layers.dense(
      x, hidden_dim1, reduced_dims=[rows_dim, cols_dim],
      activation=mtf.relu, name="hidden1")
  h2 = mtf_layers.dense(
      h1, hidden_dim2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(h2, classes_dim, name="logits")
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
def testMultiheadAttention(self, kv_channels, heads):
  batch = 2
  length = 8
  channels = 3
  query = tf.random_normal([batch, length, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)
  channels_dim = mtf.Dimension("channels", channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, length_dim, channels_dim]))
  mtf_outputs = mtf_layers.multihead_attention(
      mtf_query,
      memory_antecedent=None,
      mask=None,
      kv_channels=kv_channels_dim,
      heads=heads_dim)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, query.shape)
def testDotProductAttention(
    self, batch, heads, length_q, length_kv, depth_k, depth_v):
  query = tf.random_normal([batch, heads, length_q, depth_k])
  key = tf.random_normal([batch, heads, length_kv, depth_k])
  value = tf.random_normal([batch, heads, length_kv, depth_v])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  heads_dim = mtf.Dimension("heads", heads)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_kv_dim = mtf.Dimension("length_kv", length_kv)
  depth_k_dim = mtf.Dimension("depth_k", depth_k)
  depth_v_dim = mtf.Dimension("depth_v", depth_v)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, heads_dim, length_q_dim, depth_k_dim]))
  mtf_key = mtf.import_tf_tensor(
      mesh, key,
      shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim, depth_k_dim]))
  mtf_value = mtf.import_tf_tensor(
      mesh, value,
      shape=mtf.Shape([batch_dim, heads_dim, length_kv_dim, depth_v_dim]))
  mtf_outputs = mtf_layers.dot_product_attention(
      mtf_query, mtf_key, mtf_value, mask=None)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
def testMaskedLocalAttention1D(
    self, batch, length, io_channels, kv_channels, heads, block_length):
  length_q = length
  length_m = length
  query = tf.random_normal([batch, length_q, io_channels])
  memory = tf.random_normal([batch, length_m, io_channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_m_dim = mtf.Dimension("length_m", length_m)
  io_channels_dim = mtf.Dimension("io_channels", io_channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  mtf_query = mtf.import_tf_tensor(
      mesh, query,
      shape=mtf.Shape([batch_dim, length_q_dim, io_channels_dim]))
  mtf_memory = mtf.import_tf_tensor(
      mesh, memory,
      shape=mtf.Shape([batch_dim, length_m_dim, io_channels_dim]))
  mtf_outputs = mtf_layers.masked_local_attention_1d(
      mtf_query,
      mtf_memory,
      kv_channels=kv_channels_dim,
      heads=heads_dim,
      block_length=block_length)
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, (batch, length_q, io_channels))
def testMeshImpl(self):
  shape = mtf.Shape([mtf.Dimension("batch", 4),
                     mtf.Dimension("model", 8)])
  layout_rules = mtf.LayoutRules([("batch", "batch"),
                                  ("d_ff", "model"),
                                  ("heads", "model")])
  mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
  self.assertEqual(mesh_impl.shape, shape)
  self.assertEqual(mesh_impl.ndims, len(shape))
  self.assertEqual(mesh_impl.layout_rules, layout_rules)
  self.assertEqual(mesh_impl.size, shape.size)
  self.assertTrue(mesh_impl.supports_control_dependencies)

  batch = mtf.Dimension("batch", 128)
  length = mtf.Dimension("length", 500)
  d_ff = mtf.Dimension("d_ff", 2048)
  heads = mtf.Dimension("heads", 8)
  self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
  self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
  self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
  self.assertEqual(mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
                   mtf.TensorLayout([0, None, 1]))
def moe_v0(inputs,
           hidden_dim,
           output_dim,
           experts_dim,
           loss_coef=1e-3,
           overhead=1.0):
  """Local mixture of experts that works well on TPU.

  See https://arxiv.org/abs/1701.06538

  There are num_experts expert networks, each containing a relu-activated
  hidden layer of size hidden_size, followed by an output projection.

  The number of parameters is thus:
    num_experts * (input_size * hidden_size + hidden_size * output_size)

  The input is 3d: [batch, length, depth], consisting of the representations
  of all positions in a batch of sequences.

  Each position of each sequence is sent to 0-2 experts.  The expert choices
  and the combination weights are determined by a learned gating function.

  This function returns a small auxiliary loss that should be added to the
  training loss of the model.  This loss helps to balance expert usage.
  Without the loss, it is very likely that a few experts will be trained and
  the rest will starve.

  Several hacks are necessary to get around current TPU limitations:

  - To ensure static shapes, we enforce (by truncation/padding) that each
    sequence send the same number of elements to each expert.

    It would make more sense to enforce this equality over the entire batch,
    as opposed to on individual sequences.  This would allow more freedom
    for individual sequences to be unbalanced.  Unfortunately, that would
    slow down our hacked-up gather-by-matmul implementation.

    TODO(noam): There is no real reason for a single sequence to be the unit
      of equal allocation.  Reshaping the inputs would allow us to pick a
      different unit of equal allocation.

  TODO(noam): Factor this code better.  We want to be able to substitute
  different code for the experts themselves.  We also want to integrate this
  gating/dispatching logic into multi-device mixtures-of-experts.

  Args:
    inputs: a mtf.Tensor with shape [batch_dim, length_dim, input_dim]
    hidden_dim: a mtf.Dimension
    output_dim: a mtf.Dimension
    experts_dim: a mtf.Dimension
    loss_coef: a float scalar
    overhead: multiplicative factor of how much spare capacity to assign

  Returns:
    outputs: a Tensor with shape [batch_dim, length_dim, output_dim]
    loss: a mtf scalar
  """
  batch_dim, length_dim, input_dim = inputs.shape.dims

  # Each sequence sends expert_capacity positions to each expert.
  expert_capacity = min(
      length_dim.size,
      int((length_dim.size * 2 * overhead) / experts_dim.size))
  expert_capacity_dim = mtf.Dimension("expert_capacity", expert_capacity)

  experts_dim_unsplit = mtf.Dimension("expert_unsplit", experts_dim.size)
  batch_dim_unsplit = mtf.Dimension("batch_unsplit", batch_dim.size)

  # This is the learned gating function.
  # shape = [batch_dim, length_dim, experts_dim_unsplit]
  gates = mtf.softmax(dense(inputs, experts_dim_unsplit), experts_dim_unsplit)

  assignment_shape = mtf.TensorShape(
      [batch_dim, length_dim, experts_dim_unsplit, expert_capacity_dim])

  backward_assignment = mtf.slicewise(
      functools.partial(_truncated_top_2_gating,
                        expert_capacity=expert_capacity),
      [gates],
      output_shape=assignment_shape,
      splittable_dims=[batch_dim],
      name="backward_assignment")

  forward_assignment = mtf.cast(
      mtf.cast(backward_assignment, tf.bool), inputs.dtype)

  # put num_experts dimension first to make split easier in alltoall
  expert_inputs = mtf.einsum([inputs, forward_assignment], mtf.TensorShape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, input_dim]))

  expert_inputs = mtf.reshape(expert_inputs, mtf.TensorShape(
      [experts_dim, batch_dim_unsplit, expert_capacity_dim, input_dim]))

  # Now feed the expert inputs through the experts.
  h = dense(expert_inputs, hidden_dim, expert_dims=[experts_dim],
            activation=mtf.relu, name="x0")
  expert_output = dense(h, output_dim, expert_dims=[experts_dim], name="x1")

  expert_output = mtf.reshape(expert_output, mtf.TensorShape(
      [experts_dim_unsplit, batch_dim, expert_capacity_dim, output_dim]))

  output = mtf.einsum([expert_output, backward_assignment], mtf.TensorShape(
      [batch_dim, length_dim, output_dim]))

  importance = mtf.reduce_sum(
      backward_assignment,
      output_shape=mtf.TensorShape([batch_dim, experts_dim_unsplit]))

  loss = cv_squared(importance) * loss_coef
  return output, loss
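# Worked capacity example (illustrative sizes only): with length_dim.size=1024,
# overhead=1.0 and experts_dim.size=16,
#   expert_capacity = min(1024, int(1024 * 2 * 1.0 / 16)) = 128,
# so each sequence sends at most 128 positions to each expert; positions
# beyond that are truncated by the top-2 gating.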
def masked_local_attention_1d(query_antecedent,
                              memory_antecedent,
                              kv_channels,
                              heads,
                              block_length=128,
                              name=None):
  """Attention to the source position and a neighborhood to the left of it.

  The sequence is divided into blocks of length block_length.  Attention for a
  given query position can only see memory positions less than or equal to the
  query position, in the corresponding block and the previous block.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels] (optional). Currently, memory_length
      must have the same size as query_length, but a different name.
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    block_length: an integer, representing receptive fields for attention.
    name: an optional string.

  Returns:
    a Tensor of shape [batch, query_length, io_channels]

  Raises:
    ValueError: if channels or depth don't match.
  """
  with tf.variable_scope(
      name, default_name="multihead_attention",
      values=[query_antecedent, memory_antecedent]):

    batch, query_length, io_channels = query_antecedent.shape.dims
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)

    if memory_antecedent is None:
      memory_antecedent = rename_length_to_memory_length(
          query_antecedent, query_length.name)
    memory_batch, memory_length, memory_channels = (
        memory_antecedent.shape.dims)
    if memory_batch != batch:
      raise ValueError("memory batch must equal query batch")
    if memory_channels != io_channels:
      raise ValueError("memory channels must equal query channels")

    # Get query q, keys k and values v.
    q = mtf.einsum([query_antecedent, q_var], mtf.TensorShape(
        [batch, heads, query_length, kv_channels]))
    k = mtf.einsum([memory_antecedent, k_var], mtf.TensorShape(
        [batch, heads, memory_length, kv_channels]))
    v = mtf.einsum([memory_antecedent, v_var], mtf.TensorShape(
        [batch, heads, memory_length, kv_channels]))

    # Let's assume for now we don't have padding and the block length equally
    # divides the memory length.
    block_length = (query_length.size
                    if query_length.size < block_length * 2 else block_length)
    blength = mtf.Dimension("block_length", block_length)
    mlength = mtf.Dimension("mem_block_length", block_length)
    num_blocks = mtf.Dimension("num_blocks",
                               query_length.size // block_length)

    q = mtf.reshape(
        q, mtf.TensorShape([batch, heads, num_blocks, blength, kv_channels]))
    k = mtf.reshape(
        k, mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))
    v = mtf.reshape(
        v, mtf.TensorShape([batch, heads, num_blocks, mlength, kv_channels]))

    # compute attention for the first query block.
    def first_block_attention():
      """Compute attention for the first block."""
      first_q = mtf.slice(q, 0, 1, num_blocks.name)
      first_k = mtf.slice(k, 0, 1, num_blocks.name)
      first_v = mtf.slice(v, 0, 1, num_blocks.name)
      block = first_q.shape.dims[2]

      first_logits = mtf.einsum(
          [first_q, first_k],
          mtf.TensorShape([batch, heads, block, blength, mlength]))
      weights = mtf.softmax(first_logits, mlength)
      first_output = mtf.einsum(
          [weights, first_v],
          mtf.TensorShape([batch, heads, block, blength, kv_channels]))
      return first_output

    # Attention for first block, since query_length = key_length.
    first_output = first_block_attention()

    # Concatenate two adjacent blocks to compute the overlapping memory block.
    def local(x):
      """Helper function to get memory blocks."""
      prev_block = mtf.slice(x, 0, num_blocks.size - 1, num_blocks.name)
      cur_block = mtf.slice(x, 1, num_blocks.size - 1, num_blocks.name)
      local_block = mtf.concat([prev_block, cur_block], mlength.name)
      return local_block

    local_k = local(k)
    local_v = local(v)
    mblocks = local_k.shape.dims[2]
    mlength = local_k.shape.dims[3]

    # Calculate the causal mask to avoid peeking into the future. We compute
    # this once and reuse it for all blocks since the block_size is known.
    mask = attention_bias_local_block(query_antecedent.mesh, blength, mlength)

    # Remove the first block from q since we already computed that.
    tail_q = mtf.slice(q, 1, num_blocks.size - 1, num_blocks.name)

    # Compatibility between q and k for rest of the blocks.
    # Shape [batch, heads, num_blocks - 1, block_length, local_length]
    attention = mtf.einsum([tail_q, local_k], mtf.TensorShape(
        [batch, heads, mblocks, blength, mlength]))
    attention += mask
    attention = mtf.softmax(attention, mlength)

    # Run attention for rest of the blocks.
    # Shape [batch, heads, num_blocks-1, block_length, kv_channels]
    output = mtf.einsum([attention, local_v], mtf.TensorShape(
        [batch, heads, mblocks, blength, kv_channels]))

    # Now concatenate the first and rest of the blocks.
    final_output = mtf.concat([first_output, output], num_blocks.name)
    final_output = mtf.reshape(final_output, mtf.TensorShape(
        [batch, heads, query_length, kv_channels]))
    return mtf.einsum([final_output, o_var],
                      mtf.TensorShape([batch, query_length, io_channels]))
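# Block layout example (illustrative sizes only): with query_length=512 and
# block_length=128, num_blocks=4.  The first block attends only to itself
# (first_block_attention), and each later block i attends to the concatenation
# of block i-1 and block i (local_k / local_v), with the causal bias from
# attention_bias_local_block masking positions to the right of the query.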
def mnist_model(image, labels, mesh):
  """The model.

  Args:
    image: a tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32
    mesh: a mtf.Mesh

  Returns:
    logits: a mtf.Tensor with shape [batch, 10]
    loss: a mtf.Tensor with shape []
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  x = mtf.import_tf_tensor(mesh, tf.reshape(image, [-1, 28, 28]),
                           mtf.Shape([batch_dim, rows_dim, cols_dim]))
  x = mtf.reshape(x, [batch_dim, rows_dim, cols_dim, one_channel_dim])

  # add some convolutional layers to demonstrate that convolution works.
  # TODO(noam): get spatially-partitioned convolution working.
  fh_dim = mtf.Dimension("fh", 3)
  fw_dim = mtf.Dimension("fw", 3)
  filters1_dim = mtf.Dimension("filters1", 32)
  filters2_dim = mtf.Dimension("filters2", 32)
  kernel1 = mtf.get_variable(
      mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(
      mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])

  f1 = mtf.relu(mtf.conv2d(x, kernel1))
  f2 = mtf.relu(mtf.conv2d(f1, kernel2))
  x = mtf.reduce_mean(f2, reduced_dim=filters2_dim)

  # add some fully-connected dense layers.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
  h1 = mtf_layers.dense(
      x, hidden_dim1, reduced_dims=[rows_dim, cols_dim],
      activation=mtf.relu, name="hidden1")
  h2 = mtf_layers.dense(
      h1, hidden_dim2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(h2, classes_dim, name="logits")
  if labels is None:
    loss = None
  else:
    labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
    loss = mtf_layers.softmax_cross_entropy_with_logits(
        logits, mtf.one_hot(labels, classes_dim), classes_dim)
    loss = mtf.reduce_mean(loss)
  return logits, loss
def experts_dim(self):
  return mtf.Dimension("experts", self._hparams.moe_num_experts)

def kv_dim(self):
  return mtf.Dimension("d_kv", self._hparams.d_kv)

def heads_dim(self):
  return mtf.Dimension("heads", self._hparams.num_heads)

def memory_length_dim(self):
  return mtf.Dimension("memory_length", self._hparams.max_length)
def _decoder_layer_stack_incremental(self,
                                     x,
                                     step_num,
                                     encdec_tensors,
                                     self_attention_k,
                                     self_attention_v,
                                     encdec_attention_mask=None):
  """Decoder layer stack during inference.

  We are processing only one position at a time.

  The self-attention keys and values have already been computed for previous
  positions.  In addition to the decoder output, we need to produce the
  updated self-attention keys and values.

  If there is an encoder, then additional Tensors are supplied in
  encdec_tensors, which give us the keys and values for encoder-decoder
  attention as well as the weight matrices q_var and o_var.

  Args:
    x: a mtf.Tensor with shape [batch_dim, model_dim]
    step_num: an mtf integer Scalar
    encdec_tensors: an optional list of num_layers tuples, each of the form
      (q_var, o_var, k, v)
    self_attention_k: an optional list of num_layers Tensors each with shape
      [batch, heads, memory_length, kv_channels]
    self_attention_v: an optional list of num_layers Tensors each with shape
      [batch, heads, memory_length, kv_channels]
    encdec_attention_mask: an optional mtf.Tensor with shape
      [batch, length_dim, encoder_length_dim] containing values 0 or -inf.

  Returns:
    y: a mtf.Tensor with shape [batch_dim, model_dim]
    new_self_attention_k: a list of num_layers mtf.Tensors, with the same
      shapes as the elements of self_attention_k
    new_self_attention_v: a list of num_layers mtf.Tensors, with the same
      shapes as the elements of self_attention_v

  Raises:
    ValueError: if hparams make no sense
  """
  hparams = self._hparams
  num_layers = hparams.num_decoder_layers
  num_layer_norms = num_layers * (2 if encdec_tensors is None else 3) + 1
  layer_norms_dim = mtf.Dimension("layer_norms", num_layer_norms)
  layer_norm_combined_var = mtf.get_variable(
      x.mesh,
      "layer_norm_scale",
      mtf.Shape([layer_norms_dim, self.model_dim]),
      initializer=tf.ones_initializer(),
      activation_dtype=x.dtype)
  layer_norm_vars = mtf.unstack(layer_norm_combined_var, layer_norms_dim)

  def normalize(x):
    scale = layer_norm_vars.pop(0)
    variance = mtf.reduce_mean(mtf.square(x), reduced_dim=self.model_dim)
    return x * mtf.rsqrt(variance + hparams.norm_epsilon) * scale

  new_self_attention_k = []
  new_self_attention_v = []
  for layer in range(num_layers):
    with tf.variable_scope("layer_%d" % layer):
      # Self attention layer
      y, new_k, new_v = mtf_layers.multihead_self_attention_incremental(
          normalize(x),
          prev_k=self_attention_k[layer],
          prev_v=self_attention_v[layer],
          step_num=step_num,
          name="self_attention")
      new_self_attention_k.append(new_k)
      new_self_attention_v.append(new_v)
      x += y
      if encdec_tensors is not None:
        # Encoder-Decoder attention layer
        q_var, o_var, k, v = encdec_tensors[layer]
        x += mtf_layers.multihead_encdec_attention_incremental(
            normalize(x),
            q_var, o_var, k, v,
            encdec_attention_mask,
            name="encdec_attention")
      # ffn layer
      x += self._feedforward_layer(normalize(x), hparams)
  x = normalize(x)
  assert not layer_norm_vars
  return x, new_self_attention_k, new_self_attention_v
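# Layer-norm bookkeeping example (illustrative hparams values): with
# num_decoder_layers=6 and an encoder present, num_layer_norms
# = 6 * 3 + 1 = 19 (self-attention, encoder-decoder attention and ffn norms
# per layer, plus the final norm), and the trailing assert checks that
# normalize() consumed exactly that many scale vectors.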
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (mtf.Dimension("x", 5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    dimension = mtf.convert_to_dimension(inputs)
    self.assertEqual(dimension.name, "x")
    self.assertEqual(dimension.size, 5)

  def testConvertToDimensionGenericInputs(self):
    dimension = mtf.convert_to_dimension(None)
    self.assertEqual(dimension, None)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    shape = mtf.convert_to_shape(inputs)
    self.assertEqual(
        shape, mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]))

  def testConvertToShapeGenericInputs(self):
    shape = mtf.convert_to_shape([])
    self.assertEqual(shape.dims, [])
    shape = mtf.convert_to_shape(None)
    self.assertEqual(shape, None)
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    layout_rules = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(
        layout_rules._pairs,
        mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")

  def testTensorLayout(self):
    tensor_layout = mtf.TensorLayout([0, 2, 1])
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(0), ())
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(1), (0,))
    self.assertEqual(tensor_layout.mesh_axis_to_tensor_axis(2), (0, 2))
    tensor_layout = mtf.TensorLayout([None, 0])
    self.assertFalse(tensor_layout.is_fully_replicated)
    tensor_layout = mtf.TensorLayout([None, None, None])
    self.assertTrue(tensor_layout.is_fully_replicated)

  def testGraph(self):
    graph = mtf.Graph()
    self.assertLen(graph.operations, 0)
    self.assertLen(graph.tensors, 0)
    self.assertLen(graph.trainable_variables, 0)
    self.assertLen(graph.all_variables, 0)
    mesh = mtf.Mesh(graph, "mesh_test")
    _ = mtf.import_tf_tensor(mesh,
                             tf_tensor=tf.constant(0.),
                             shape=mtf.Shape([]))
    self.assertLen(graph.operations, 1)
    self.assertLen(graph.tensors, 1)
    self.assertLen(graph.trainable_variables, 0)
    self.assertLen(graph.all_variables, 0)
    _ = mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
    self.assertLen(graph.operations, 2)
    self.assertLen(graph.tensors, 2)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 1)
    _ = mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
    self.assertLen(graph.operations, 3)
    self.assertLen(graph.tensors, 3)
    self.assertLen(graph.trainable_variables, 1)
    self.assertLen(graph.all_variables, 2)

  def testLowering(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    inputs = tf.constant(0.)
    mtf_inputs = mtf.import_tf_tensor(mesh,
                                      tf_tensor=inputs,
                                      shape=mtf.Shape([]))
    mesh_impl = placement_mesh_impl.PlacementMeshImpl(
        shape=[], layout={}, devices=[""])
    lowering = mtf.Lowering(graph, {mesh: mesh_impl})
    outputs = lowering.export_to_tf_tensor(mtf_inputs)

    with self.test_session() as sess:
      inputs_value, outputs_value = sess.run([inputs, outputs])
      self.assertEqual(inputs_value, outputs_value)

    # Check that methods run without error.
    _ = lowering.copy_masters_to_slices()
    _ = lowering.copy_slices_to_masters()

  def testMesh(self):
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    self.assertEqual(mesh.graph, graph)

  def testMeshImpl(self):
    shape = mtf.Shape([mtf.Dimension("batch", 4),
                       mtf.Dimension("model", 8)])
    layout_rules = mtf.LayoutRules([("batch", "batch"),
                                    ("d_ff", "model"),
                                    ("heads", "model")])
    mesh_impl = mtf.MeshImpl(shape=shape, layout_rules=layout_rules)
    self.assertEqual(mesh_impl.shape, shape)
    self.assertEqual(mesh_impl.ndims, len(shape))
    self.assertEqual(mesh_impl.layout_rules, layout_rules)
    self.assertEqual(mesh_impl.size, shape.size)
    self.assertTrue(mesh_impl.supports_control_dependencies)

    batch = mtf.Dimension("batch", 128)
    length = mtf.Dimension("length", 500)
    d_ff = mtf.Dimension("d_ff", 2048)
    heads = mtf.Dimension("heads", 8)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(batch), 0)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(d_ff), 1)
    self.assertEqual(mesh_impl.tensor_dimension_to_mesh_axis(heads), 1)
    self.assertEqual(
        mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff])),
        mtf.TensorLayout([0, None, 1]))
def testConvertToShape(self, inputs):
  shape = mtf.convert_to_shape(inputs)
  self.assertEqual(
      shape, mtf.Shape([mtf.Dimension("x", 4), mtf.Dimension("y", 8)]))
def feedforward_dim(self):
  return mtf.Dimension("d_ff", self._hparams.d_ff)