def layer_norm(x, dim, epsilon=1e-6, name="layer_prepostprocess"):
  """Normalizes `x` along `dim`, then applies a learned scale and bias.

  The mean and variance are reduced over `dim` (the statistics have shape
  `x.shape - dim`), and the scale and bias variables have shape `[dim]`.

  Args:
    x: a mtf.Tensor whose shape contains dim.
    dim: a mtf.Dimension to normalize over.
    epsilon: a floating point number added to the variance before taking the
      reciprocal square root.
    name: a string. Prefix of the variable scope; "/layer_norm" is appended.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  def _per_dim_variable(var_name, initializer):
    # Both learned parameters live on x's mesh, have shape [dim] and x's dtype.
    return mtf.get_variable(
        x.mesh, var_name, mtf.Shape([dim]),
        initializer=initializer,
        activation_dtype=x.dtype)

  with tf.variable_scope(name + "/layer_norm"):
    gain = _per_dim_variable("layer_norm_scale", tf.ones_initializer())
    offset = _per_dim_variable("layer_norm_bias", tf.zeros_initializer())

    stats_shape = x.shape - dim
    mu = mtf.reduce_mean(x, output_shape=stats_shape)
    squared_deviation = mtf.square(x - mu)
    sigma_sq = mtf.reduce_mean(squared_deviation, output_shape=stats_shape)

    centered = x - mu
    inv_std = mtf.rsqrt(sigma_sq + epsilon)
    normalized = centered * inv_std
    return normalized * gain + offset
def create_positional_emb_2d(self, targets):
  """Builds a learned 2d positional embedding for images.

  Two embedding tables of shape [max_length_dim, model_dim] are created, one
  indexed by row position and one by column position. Each lookup is broadcast
  to [rows_dim, cols_dim, model_dim] and the two results are added.

  Args:
    targets: a tensor; only its `mesh` attribute is read here.

  Returns:
    The sum of the broadcast row and column embeddings.
  """
  mesh = targets.mesh
  table_shape = mtf.Shape([self.max_length_dim, self.model_dim])
  grid_shape = mtf.Shape([self.rows_dim, self.cols_dim, self.model_dim])

  def _embedding_table(var_name):
    return mtf.get_variable(
        mesh, var_name, table_shape,
        initializer=tf.random_normal_initializer(),
        activation_dtype=self.activation_type)

  def _lookup_on_grid(table, positions):
    gathered = mtf.gather(table, positions, self.max_length_dim)
    return mtf.broadcast(gathered, grid_shape)

  row_table = _embedding_table("positional_emb_rows")
  col_table = _embedding_table("positional_emb_cols")
  row_positions = mtf.range(mesh, self.rows_dim, dtype=tf.int32)
  col_positions = mtf.range(mesh, self.cols_dim, dtype=tf.int32)
  row_embedding = _lookup_on_grid(row_table, row_positions)
  col_embedding = _lookup_on_grid(col_table, col_positions)
  return row_embedding + col_embedding
def testGraph(self):
  """Checks Graph bookkeeping as a tensor and two variables are added."""
  graph = mtf.Graph()

  def check_counts(operations, tensors, trainable_variables, all_variables):
    self.assertLen(graph.operations, operations)
    self.assertLen(graph.tensors, tensors)
    self.assertLen(graph.trainable_variables, trainable_variables)
    self.assertLen(graph.all_variables, all_variables)

  # A fresh graph is empty.
  check_counts(0, 0, 0, 0)

  mesh = mtf.Mesh(graph, "mesh_test")
  mtf.import_tf_tensor(mesh,
                       tf_tensor=tf.constant(0.),
                       shape=mtf.Shape([]))
  # Importing a tensor adds one operation and one tensor, but no variable.
  check_counts(1, 1, 0, 0)

  mtf.get_variable(mesh, "variable_0", mtf.Shape([]), trainable=True)
  check_counts(2, 2, 1, 1)

  # A non-trainable variable is counted in all_variables only.
  mtf.get_variable(mesh, "variable_1", mtf.Shape([]), trainable=False)
  check_counts(3, 3, 1, 2)
def _embedding_and_softmax_vars(self, mesh):
  """Creates the embedding, softmax and positional-embedding variables.

  Args:
    mesh: the mesh on which the variables are created.

  Returns:
    A tuple (inputs_embedding_var, targets_embedding_var, softmax_var,
    positional_embedding_var). inputs_embedding_var is None when
    self.has_input is false, and is the targets embedding itself when
    hparams.shared_embedding is set. softmax_var is the targets embedding
    scaled by model_dim.size ** -0.5 when
    hparams.shared_embedding_and_softmax_weights is set; otherwise it is a
    separate variable.
  """
  hparams = self._hparams

  def _matrix(var_name, first_dim, initializer):
    # All variables here are [first_dim, model_dim] matrices.
    return mtf.get_variable(
        mesh, var_name, mtf.Shape([first_dim, self.model_dim]),
        initializer=initializer,
        activation_dtype=self.activation_dtype)

  targets_embedding_var = _matrix(
      "targets_embedding", self.targets_vocab_dim,
      tf.random_normal_initializer())

  if not self.has_input:
    inputs_embedding_var = None
  elif hparams.shared_embedding:
    inputs_embedding_var = targets_embedding_var
  else:
    inputs_embedding_var = _matrix(
        "inputs_embedding", self.inputs_vocab_dim,
        tf.random_normal_initializer())

  if hparams.shared_embedding_and_softmax_weights:
    softmax_var = targets_embedding_var * (self.model_dim.size ** -0.5)
  else:
    softmax_var = _matrix(
        "softmax", self.targets_vocab_dim,
        tf.random_normal_initializer(stddev=self.model_dim.size**-0.5))

  positional_embedding_var = _matrix(
      "positional_embedding", self.max_length_dim,
      tf.random_normal_initializer())

  return (inputs_embedding_var, targets_embedding_var, softmax_var,
          positional_embedding_var)
def mnist_model(image, labels, mesh):
  """Builds a three-layer fully-connected MNIST classifier on `mesh`.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
      skip the loss.
    mesh: a mtf.Mesh

  Returns:
    logits: the output of the final dense layer, whose last dimension is
      "classes" (size 10).
    loss: the mean softmax cross-entropy reduced with mtf.reduce_mean, or None
      when labels is None.
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)

  pixels = tf.reshape(image, [-1, 28, 28])
  x = mtf.import_tf_tensor(
      mesh, pixels, mtf.Shape([batch_dim, rows_dim, cols_dim]))

  # The first layer reduces over both spatial dimensions at once.
  hidden = mtf_layers.dense(
      x, hidden_dim1,
      reduced_dims=[rows_dim, cols_dim],
      activation=mtf.relu,
      name="hidden1")
  hidden = mtf_layers.dense(
      hidden, hidden_dim2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(hidden, classes_dim, name="logits")

  if labels is None:
    return logits, None

  mtf_labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
  per_example_loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(per_example_loss)
def create_positional_emb_2d(self, targets, max_length_dim, model_dim):
  """Builds a learned 2d positional embedding for images.

  The "rows" dimension has size hparams.img_len and the "cols" dimension has
  size hparams.img_len * hparams.num_channels. A row table and a column table,
  each of shape [max_length_dim, model_dim], are looked up by position,
  broadcast to [rows, cols, model_dim] and added.

  Args:
    targets: a tensor; only its `mesh` attribute is read here.
    max_length_dim: first dimension of both embedding tables, and the
      dimension gathered over.
    model_dim: second dimension of both embedding tables.

  Returns:
    The sum of the broadcast row and column embeddings.
  """
  mesh = targets.mesh
  hparams = self._hparams
  activation_dtype = self.set_activation_type()

  rows_dim = mtf.Dimension("rows", hparams.img_len)
  cols_dim = mtf.Dimension("cols", hparams.img_len * hparams.num_channels)
  table_shape = mtf.Shape([max_length_dim, model_dim])
  grid_shape = mtf.Shape([rows_dim, cols_dim, model_dim])

  def _embedding_table(var_name):
    return mtf.get_variable(
        mesh, var_name, table_shape,
        initializer=tf.random_normal_initializer(),
        activation_dtype=activation_dtype)

  def _lookup_on_grid(table, positions):
    gathered = mtf.gather(table, positions, max_length_dim)
    return mtf.broadcast(gathered, grid_shape)

  row_table = _embedding_table("positional_emb_rows")
  col_table = _embedding_table("positional_emb_cols")
  row_positions = mtf.range(mesh, rows_dim, dtype=tf.int32)
  col_positions = mtf.range(mesh, cols_dim, dtype=tf.int32)
  row_embedding = _lookup_on_grid(row_table, row_positions)
  col_embedding = _lookup_on_grid(col_table, col_positions)
  return row_embedding + col_embedding
def multihead_attention(query_antecedent,
                        memory_antecedent,
                        mask,
                        kv_channels,
                        heads,
                        dropout=0.0,
                        dropout_broadcast_dims=None,
                        name="multihead_attention"):
  """Multihead scaled-dot-product attention with input/output projections.

  A single packed variable holds the four weight matrices, so the query and
  memory antecedents must share their channel dimension (io_channels), and the
  keys and values share kv_channels.

  Args:
    query_antecedent: a mtf.Tensor with shape
      [batch, query_length, io_channels]
    memory_antecedent: a mtf.Tensor with shape
      [batch, memory_length, io_channels], or None. When None, the query
      antecedent is used after rename_length_to_memory_length.
    mask: mask Tensor (see attention_mask())
    kv_channels: a mtf.Dimension (the size of the key and value vectors)
    heads: a mtf.Dimension (the number of heads)
    dropout: a floating point value
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string.

  Returns:
    A mtf.Tensor with shape [batch, query_length, io_channels]

  Raises:
    ValueError: if the memory batch or channel dimension differs from the
      query's.
  """
  batch, query_length, io_channels = query_antecedent.shape.dims
  with tf.variable_scope(name,
                         default_name="multihead_attention",
                         values=[query_antecedent, memory_antecedent]):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)

    # Self-attention: reuse the query as memory under a renamed length dim.
    if memory_antecedent is None:
      memory_antecedent = rename_length_to_memory_length(
          query_antecedent, query_length.name)

    memory_batch, memory_length, memory_channels = (
        memory_antecedent.shape.dims)
    if memory_batch != batch:
      raise ValueError("memory batch must equal query batch")
    if memory_channels != io_channels:
      raise ValueError("memory channels must equal query channels")

    def _project(antecedent, weights, length_dim):
      projected_shape = mtf.Shape([batch, heads, length_dim, kv_channels])
      return mtf.einsum([antecedent, weights], projected_shape)

    q = _project(query_antecedent, q_var, query_length)
    k = _project(memory_antecedent, k_var, memory_length)
    v = _project(memory_antecedent, v_var, memory_length)

    attended = dot_product_attention(
        q, k, v, mask, dropout, dropout_broadcast_dims)
    output_shape = mtf.Shape([batch, query_length, io_channels])
    return mtf.einsum([attended, o_var], output_shape)
def batch_norm(x, is_training, momentum, epsilon=1e-9, name=None):
  """Batch normalization over the first dimension of `x`.

  Args:
    x: a mtf.Tensor whose shape contains [batch_dim, ..., dim]
    is_training: a boolean, whether mode is training.
    momentum: a floating point number, specifying batch norm decay value.
    epsilon: a floating point number.
    name: a string. variable scope.

  Returns:
    a mtf.Tensor with same shape as x.
  """
  with tf.variable_scope(name, default_name="batch_norm", values=[x]):
    batch_dim = x.shape.dims[0]
    stats_shape = x.shape - batch_dim

    # NOTE(review): scale and bias are created with shape [batch_dim], i.e.
    # the first dimension of x, while the statistics use x.shape - batch_dim.
    # Confirm this is the intended parameter shape.
    scale = mtf.get_variable(
        x.mesh, "batch_norm_scale", mtf.Shape([batch_dim]),
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype)
    bias = mtf.get_variable(
        x.mesh, "batch_norm_bias", mtf.Shape([batch_dim]),
        initializer=tf.zeros_initializer(),
        activation_dtype=x.dtype)

    # Non-trainable running statistics.
    moving_mean = mtf.get_variable(
        x.mesh, "moving_mean", stats_shape,
        initializer=tf.random_normal_initializer(stddev=1.0),
        activation_dtype=x.dtype,
        trainable=False)
    moving_variance = mtf.get_variable(
        x.mesh, "moving_variance", stats_shape,
        initializer=tf.ones_initializer(),
        activation_dtype=x.dtype,
        trainable=False)

    if not is_training:
      # Eval and test: normalize with the running mean and variance.
      norm_x = (x - moving_mean) * mtf.rsqrt(moving_variance + epsilon)
    else:
      # Training: normalize with statistics reduced over the batch dim.
      batch_mean = mtf.reduce_mean(x, output_shape=stats_shape)
      batch_variance = mtf.reduce_mean(
          mtf.square(x - batch_mean), output_shape=stats_shape)
      norm_x = (x - batch_mean) * mtf.rsqrt(batch_variance + epsilon)

      # Exponential-moving-average update of the running statistics.
      # NOTE(review): the assign results are not used or returned here;
      # verify that the caller arranges for them to run.
      mtf.assign(
          moving_mean,
          momentum * moving_mean + (1 - momentum) * batch_mean)
      mtf.assign(
          moving_variance,
          momentum * moving_variance + (1 - momentum) * batch_variance)

    return norm_x * scale + bias
def mnist_model(image, labels, mesh):
  """Builds an MNIST classifier: two conv layers, then three dense layers.

  Args:
    image: tf.Tensor with shape [batch, 28*28]
    labels: a tf.Tensor with shape [batch] and dtype tf.int32, or None to
      skip the loss.
    mesh: a mtf.Mesh

  Returns:
    logits: the output of the final dense layer, whose last dimension is
      "classes" (size 10).
    loss: the mean softmax cross-entropy reduced with mtf.reduce_mean, or None
      when labels is None.
  """
  batch_dim = mtf.Dimension("batch", FLAGS.batch_size)
  rows_dim = mtf.Dimension("rows", 28)
  cols_dim = mtf.Dimension("cols", 28)
  classes_dim = mtf.Dimension("classes", 10)
  one_channel_dim = mtf.Dimension("one_channel", 1)

  pixels = tf.reshape(image, [-1, 28, 28])
  x = mtf.import_tf_tensor(
      mesh, pixels, mtf.Shape([batch_dim, rows_dim, cols_dim]))
  # Append a size-1 channel dimension before the convolutions.
  x = mtf.reshape(x, [batch_dim, rows_dim, cols_dim, one_channel_dim])

  # Convolutional layers, included to demonstrate that convolution works.
  # TODO(noam): get spatially-partitioned convolution working.
  fh_dim = mtf.Dimension("fh", 3)
  fw_dim = mtf.Dimension("fw", 3)
  filters1_dim = mtf.Dimension("filters1", 32)
  filters2_dim = mtf.Dimension("filters2", 32)
  kernel1 = mtf.get_variable(
      mesh, "kernel1", [fh_dim, fw_dim, one_channel_dim, filters1_dim])
  kernel2 = mtf.get_variable(
      mesh, "kernel2", [fh_dim, fw_dim, filters1_dim, filters2_dim])
  conv1 = mtf.relu(mtf.conv2d(x, kernel1))
  conv2 = mtf.relu(mtf.conv2d(conv1, kernel2))
  # Average away the filter dimension.
  x = mtf.reduce_mean(conv2, reduced_dim=filters2_dim)

  # Fully-connected dense layers.
  hidden_dim1 = mtf.Dimension("hidden1", FLAGS.hidden_size)
  hidden_dim2 = mtf.Dimension("hidden2", FLAGS.hidden_size)
  hidden = mtf_layers.dense(
      x, hidden_dim1,
      reduced_dims=[rows_dim, cols_dim],
      activation=mtf.relu,
      name="hidden1")
  hidden = mtf_layers.dense(
      hidden, hidden_dim2, activation=mtf.relu, name="hidden2")
  logits = mtf_layers.dense(hidden, classes_dim, name="logits")

  if labels is None:
    return logits, None

  mtf_labels = mtf.import_tf_tensor(mesh, labels, mtf.Shape([batch_dim]))
  per_example_loss = mtf_layers.softmax_cross_entropy_with_logits(
      logits, mtf.one_hot(mtf_labels, classes_dim), classes_dim)
  return logits, mtf.reduce_mean(per_example_loss)
def multihead_self_attention_incremental(query_antecedent,
                                         prev_k,
                                         prev_v,
                                         step_num,
                                         name="multihead_attention"):
  """Incremental self-attention (one decode step).

  The current step's key and value are multiplied by a one-hot of `step_num`
  over memory_length and added to `prev_k` / `prev_v`. Memory positions
  greater than `step_num` are masked with -1e9.

  A single packed variable holds the four weight matrices, so the query and
  memory antecedents share io_channels and the keys and values share
  kv_channels.

  Args:
    query_antecedent: a mtf.Tensor with shape [batch..., io_channels]
    prev_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    prev_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    step_num: mtf Scalar with dtype tf.int32
    name: an optional string.

  Returns:
    y: A mtf.Tensor with shape [batch..., io_channels]
    new_k: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]
    new_v: mtf.Tensor with shape [batch..., heads, memory_length, kv_channels]

  Raises:
    ValueError: if the dimensions do not match.
  """
  batch_dims = query_antecedent.shape.dims[:-1]
  io_channels = query_antecedent.shape.dims[-1]
  heads, memory_length, kv_channels = prev_k.shape.dims[-3:]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_var, k_var, v_var, o_var = multihead_attention_vars(
        query_antecedent.mesh, heads, io_channels, kv_channels,
        query_antecedent.dtype)
    step_shape = mtf.Shape(batch_dims + [heads, kv_channels])

    def _project(weights):
      # In self-attention the memory antecedent is the query antecedent.
      return mtf.einsum([query_antecedent, weights], step_shape)

    def _write_at_step(cache, update):
      # Place `update` at position step_num along memory_length.
      placed = mtf.multiply(
          update, mtf.one_hot(step_num, memory_length),
          output_shape=cache.shape)
      return cache + placed

    q = _project(q_var)
    step_k = _project(k_var)
    step_v = _project(v_var)
    k = _write_at_step(prev_k, step_k)
    v = _write_at_step(prev_v, step_v)

    positions = mtf.range(
        query_antecedent.mesh, memory_length, dtype=tf.int32)
    is_future = mtf.greater(positions, step_num)
    mask = mtf.to_float(is_future) * -1e9

    o = dot_product_attention(q, k, v, mask)
    y = mtf.einsum([o, o_var], query_antecedent.shape)
    return y, k, v
def dense(x, output_dim, reduced_dims=None, expert_dims=None,
          use_bias=True, activation=None, name=None):
  """Dense layer computing activation(einsum(x, kernel) + bias).

  The kernel has shape expert_dims + reduced_dims + [output_dim] and is
  initialized from a normal distribution whose stddev is the product of the
  reduced dimension sizes raised to -0.5.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
  expert_dims = [] if expert_dims is None else expert_dims
  if reduced_dims is None:
    reduced_dims = x.shape.dims[-1:]

  kernel_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
  kept_dims = [dim for dim in x.shape.dims if dim not in reduced_dims]
  output_shape = mtf.Shape(kept_dims + [output_dim])

  with tf.variable_scope(name, default_name="dense"):
    fan_in = mtf.list_product(d.size for d in reduced_dims)
    stddev = fan_in**-0.5
    kernel = mtf.get_variable(
        x.mesh, "kernel", kernel_shape,
        initializer=tf.random_normal_initializer(stddev=stddev),
        activation_dtype=x.dtype)
    y = mtf.einsum([x, kernel], output_shape)
    if use_bias:
      bias_shape = mtf.Shape(expert_dims + [output_dim])
      bias = mtf.get_variable(
          x.mesh, "bias", bias_shape,
          initializer=tf.zeros_initializer(),
          activation_dtype=x.dtype)
      y += bias
    if activation is not None:
      y = activation(y)
    return y
def first_block_attention():
  """Computes attention for the first block.

  Slices the first entry along the num_blocks dimension out of the enclosing
  scope's q, k and v, and attends the sliced queries to the sliced keys and
  values.

  Returns:
    The attention output, with shape
    [batch, heads, block, blength, kv_channels].
  """
  def _first_block(tensor):
    return mtf.slice(tensor, 0, 1, num_blocks.name)

  first_q = _first_block(q)
  first_k = _first_block(k)
  first_v = _first_block(v)
  # Dimension at index 2 of the sliced query; used as the block dimension.
  block = first_q.shape.dims[2]

  logits_shape = mtf.Shape([batch, heads, block, blength, mlength])
  first_logits = mtf.einsum([first_q, first_k], logits_shape)
  weights = mtf.softmax(first_logits, mlength)

  output_shape = mtf.Shape([batch, heads, block, blength, kv_channels])
  return mtf.einsum([weights, first_v], output_shape)
def testWeightsNonzero(self):
  """Compares mtf_layers.weights_nonzero with common_layers.weights_nonzero."""
  inputs = tf.constant([[3, 1, 0], [1, 0, 0]])
  num_rows, num_cols = inputs.shape.as_list()

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  input_shape = mtf.Shape([
      mtf.Dimension("batch", num_rows),
      mtf.Dimension("channels", num_cols),
  ])
  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=input_shape)
  mtf_outputs = mtf_layers.weights_nonzero(mtf_inputs)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
  expected_outputs = common_layers.weights_nonzero(inputs)

  self.evaluate(lowering.copy_masters_to_slices())
  actual, expected = self.evaluate([actual_outputs, expected_outputs])
  self.assertAllEqual(actual, expected)
def testLayerNorm(self):
  """Checks mtf layer_norm output shape against common_layers.layer_norm."""
  batch, channels = 2, 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  input_shape = mtf.Shape([batch_dim, channels_dim])
  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=input_shape)
  mtf_outputs = mtf_layers.layer_norm(mtf_inputs, dim=channels_dim)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)
  expected_outputs = common_layers.layer_norm(inputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    # Initialize variables first, then copy masters to slices.
    sess.run(init)
    sess.run(tf_group)
    actual, expected = sess.run([actual_outputs, expected_outputs])

  # Only shapes are compared, not values.
  self.assertEqual(actual.shape, expected.shape)
def testDense(self, units, use_bias):
  """Checks mtf dense output shape against tf.keras.layers.Dense."""
  batch, channels = 2, 3
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  depth_dim = mtf.Dimension("depth", units)
  input_shape = mtf.Shape([batch_dim, channels_dim])
  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=input_shape)
  mtf_outputs = mtf_layers.dense(
      mtf_inputs,
      output_dim=depth_dim,
      reduced_dims=[channels_dim],
      activation=mtf.relu,
      use_bias=use_bias)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  reference_layer = tf.keras.layers.Dense(
      units=units, activation=tf.nn.relu, use_bias=use_bias)
  expected_outputs = reference_layer(inputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual, expected = sess.run([actual_outputs, expected_outputs])

  # Only shapes are compared, not values.
  self.assertEqual(actual.shape, expected.shape)
def testDenseReluDense(self):
  """Checks that dense_relu_dense preserves the input shape."""
  batch, channels, hidden = 2, 3, 5
  inputs = tf.random_normal([batch, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  channels_dim = mtf.Dimension("channels", channels)
  hidden_dim = mtf.Dimension("hidden", hidden)
  input_shape = mtf.Shape([batch_dim, channels_dim])
  mtf_inputs = mtf.import_tf_tensor(mesh, inputs, shape=input_shape)
  mtf_outputs = mtf_layers.dense_relu_dense(
      mtf_inputs, hidden_channels=hidden_dim)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, inputs.shape)
def multihead_attention_vars(
    mesh, heads, io_channels, kv_channels, activation_dtype):
  """Creates the packed parameter variable for multihead attention.

  One variable of shape [qkvo, heads, io_channels, kv_channels] (qkvo has
  size 4) is created and unstacked along qkvo. Each of the four slices is
  drawn from a unit normal and scaled by its own standard deviation.

  Args:
    mesh: a Mesh
    heads: a Dimension
    io_channels: a Dimension
    kv_channels: a Dimension
    activation_dtype: a tf.dtype

  Returns:
    q_var: a Tensor with shape [heads, io_channels, kv_channels]
    k_var: a Tensor with shape [heads, io_channels, kv_channels]
    v_var: a Tensor with shape [heads, io_channels, kv_channels]
    o_var: a Tensor with shape [heads, io_channels, kv_channels]
  """
  qkvo = mtf.Dimension("qkvo", 4)
  qk_stddev = (io_channels.size ** -0.5) * (kv_channels.size ** -0.25)
  v_stddev = io_channels.size ** -0.5
  o_stddev = (io_channels.size * heads.size) ** -0.5
  # One stddev per slice along qkvo, in q, k, v, o order.
  slice_stddevs = [qk_stddev, qk_stddev, v_stddev, o_stddev]

  def _scaled_normal_initializer(shape, dtype=None, partition_info=None,
                                 verify_shape=None):
    del partition_info, verify_shape
    unit_normal = tf.random_normal(shape, dtype=dtype)
    # Reshape to [4, 1, 1, 1] so each stddev broadcasts over its own slice.
    return unit_normal * tf.reshape(slice_stddevs, [4, 1, 1, 1])

  packed_shape = mtf.Shape([qkvo, heads, io_channels, kv_channels])
  packed = mtf.get_variable(
      mesh, "qkvo", packed_shape,
      initializer=_scaled_normal_initializer,
      activation_dtype=activation_dtype)
  q_var, k_var, v_var, o_var = mtf.unstack(packed, qkvo)
  return q_var, k_var, v_var, o_var
def dense_relu_dense(x, hidden_channels, dropout=0.0,
                     dropout_broadcast_dims=None, name=None):
  """Hidden layer with ReLU activation followed by linear projection.

  Both weight matrices are stored in one "kernel" variable of shape
  [io, io_channels, hidden_channels] (io has size 2) and unstacked along io.
  The output has the same number of channels as the input.

  Args:
    x: a mtf.Tensor
    hidden_channels: a mtf.Dimension - channels in the hidden layer
    dropout: an optional float
    dropout_broadcast_dims: an optional list of mtf.Dimension
    name: an optional string

  Returns:
    a mtf.Tensor with the same shape as x.
  """
  with tf.variable_scope(name, default_name="dense_relu_dense"):
    io_channels = x.shape.dims[-1]
    init_stddev = (hidden_channels.size * io_channels.size) ** -0.25
    io = mtf.Dimension("io", 2)
    packed_kernel = mtf.get_variable(
        x.mesh, "kernel",
        mtf.Shape([io, io_channels, hidden_channels]),
        initializer=tf.random_normal_initializer(stddev=init_stddev),
        activation_dtype=x.dtype)
    w_in, w_out = mtf.unstack(packed_kernel, io)

    hidden = mtf.relu(mtf.einsum([x, w_in]))
    if dropout != 0.0:
      # mtf.dropout takes the keep probability, hence 1.0 - dropout.
      hidden = mtf.dropout(
          hidden, 1.0 - dropout,
          noise_shape=hidden.shape - dropout_broadcast_dims)
    return mtf.einsum([hidden, w_out])
def multihead_encdec_attention_incremental(query_antecedent,
                                           q_var, o_var, k, v,
                                           mask,
                                           name="multihead_attention"):
  """Incremental attention over encoder (one decode step).

  The keys and values are supplied by the caller; only the query projection
  and the output projection are applied here.

  memory_dims is a subset of query_dims

  Args:
    query_antecedent: a mtf.Tensor with shape query_dims + [io_channels]
    q_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    o_var: a mtf.Tensor with shape [heads, io_channels, kv_channels]
    k: memory_dims + [heads, memory_length, kv_channels]
    v: memory_dims + [heads, memory_length, kv_channels]
    mask: mask Tensor (see attention_mask())
    name: an optional string.

  Returns:
    A mtf.Tensor with the same shape as query_antecedent.
  """
  heads, unused_memory_length, kv_channels = k.shape.dims[-3:]
  del unused_memory_length
  query_dims = query_antecedent.shape.dims[:-1]
  with tf.variable_scope(name, default_name="multihead_attention"):
    q_shape = mtf.Shape(query_dims + [heads, kv_channels])
    q = mtf.einsum([query_antecedent, q_var], q_shape)
    attended = dot_product_attention(q, k, v, mask)
    return mtf.einsum([attended, o_var], query_antecedent.shape)
def testDotProductAttention(self, batch, heads, length_q, length_kv,
                            depth_k, depth_v):
  """Checks the output shape of mtf_layers.dot_product_attention."""
  query = tf.random_normal([batch, heads, length_q, depth_k])
  key = tf.random_normal([batch, heads, length_kv, depth_k])
  value = tf.random_normal([batch, heads, length_kv, depth_v])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  heads_dim = mtf.Dimension("heads", heads)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_kv_dim = mtf.Dimension("length_kv", length_kv)
  depth_k_dim = mtf.Dimension("depth_k", depth_k)
  depth_v_dim = mtf.Dimension("depth_v", depth_v)

  def _import(tensor, length_dim, depth_dim):
    shape = mtf.Shape([batch_dim, heads_dim, length_dim, depth_dim])
    return mtf.import_tf_tensor(mesh, tensor, shape=shape)

  mtf_query = _import(query, length_q_dim, depth_k_dim)
  mtf_key = _import(key, length_kv_dim, depth_k_dim)
  mtf_value = _import(value, length_kv_dim, depth_v_dim)
  mtf_outputs = mtf_layers.dot_product_attention(
      mtf_query, mtf_key, mtf_value, mask=None)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, (batch, heads, length_q, depth_v))
def my_gather(tensor):
  """Gathers `tensor` at `top_beam_index` along `beam_dim`.

  The output shape is tensor's shape with `beam_dim` replaced by
  `double_beam`. All three names come from the enclosing scope.
  """
  output_dims = [
      double_beam if dim == beam_dim else dim
      for dim in tensor.shape.dims
  ]
  return mtf.gather(
      tensor, top_beam_index, beam_dim,
      output_shape=mtf.Shape(output_dims))
def gather(tensor, name):
  """Gathers `tensor` at `topk_indices` along `old_beam_dim`.

  Runs inside the name scope `prefix + name`. The output shape is tensor's
  shape with `old_beam_dim` replaced by `beam_dim`. `prefix`, `topk_indices`
  and both beam dimensions come from the enclosing scope.
  """
  with tf.name_scope(prefix + name):
    output_dims = [
        beam_dim if dim == old_beam_dim else dim
        for dim in tensor.shape.dims
    ]
    return mtf.gather(
        tensor, topk_indices, old_beam_dim,
        output_shape=mtf.Shape(output_dims))
def testMaskedLocalAttention1D(self, batch, length, io_channels,
                               kv_channels, heads, block_length):
  """Checks the output shape of mtf_layers.masked_local_attention_1d."""
  # Query and memory use the same length in this test.
  length_q = length
  length_m = length
  query = tf.random_normal([batch, length_q, io_channels])
  memory = tf.random_normal([batch, length_m, io_channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_q_dim = mtf.Dimension("length_q", length_q)
  length_m_dim = mtf.Dimension("length_m", length_m)
  io_channels_dim = mtf.Dimension("io_channels", io_channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  query_shape = mtf.Shape([batch_dim, length_q_dim, io_channels_dim])
  mtf_query = mtf.import_tf_tensor(mesh, query, shape=query_shape)
  memory_shape = mtf.Shape([batch_dim, length_m_dim, io_channels_dim])
  mtf_memory = mtf.import_tf_tensor(mesh, memory, shape=memory_shape)

  mtf_outputs = mtf_layers.masked_local_attention_1d(
      mtf_query,
      mtf_memory,
      kv_channels=kv_channels_dim,
      heads=heads_dim,
      block_length=block_length)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, (batch, length_q, io_channels))
class MeshTensorFlowTest(parameterized.TestCase, tf.test.TestCase):
  """Tests for the mtf dimension, shape and layout-rule converters."""

  @parameterized.parameters(
      (mtf.Dimension(name="x", size=5),),
      (("x", 5),),
  )
  def testConvertToDimension(self, inputs):
    """A Dimension and a (name, size) tuple both convert to Dimension x:5."""
    converted = mtf.convert_to_dimension(inputs)
    self.assertEqual(converted.name, "x")
    self.assertEqual(converted.size, 5)

  def testConvertToDimensionGenericInputs(self):
    """None passes through; a bare int raises TypeError."""
    self.assertEqual(mtf.convert_to_dimension(None), None)
    with self.assertRaises(TypeError):
      mtf.convert_to_dimension(5)

  @parameterized.parameters(
      (mtf.Shape([mtf.Dimension(name="x", size=4),
                  mtf.Dimension(name="y", size=8)]),),
      ("x:4;y:8",),
      ("x:4.y:8",),
      ("x:4 y:8",),
      ("x:4,y:8",),
  )
  def testConvertToShape(self, inputs):
    """A Shape and each delimited string form convert to the same Shape."""
    expected = mtf.Shape([
        mtf.Dimension(name="x", size=4),
        mtf.Dimension(name="y", size=8),
    ])
    converted = mtf.convert_to_shape(inputs)
    self.assertEqual(converted, expected)

  def testConvertToShapeGenericInputs(self):
    """None passes through; a malformed string raises ValueError."""
    self.assertEqual(mtf.convert_to_shape(None), None)
    with self.assertRaises(ValueError):
      mtf.convert_to_shape("x;4")

  @parameterized.parameters(
      (mtf.LayoutRules([("d_ff", "model"), ("heads", "model")]),),
      ("d_ff:model;heads:model",),
      ("d_ff:model.heads:model",),
      ("d_ff:model heads:model",),
      ("d_ff:model,heads:model",),
      ([("d_ff", "model"), ("heads", "model")],),
  )
  def testConvertToLayoutRules(self, inputs):
    """Every supported input form yields the same rule pairs."""
    expected = mtf.LayoutRules([("d_ff", "model"), ("heads", "model")])
    converted = mtf.convert_to_layout_rules(inputs)
    self.assertEqual(converted._pairs, expected._pairs)

  def testConvertToLayoutRulesGenericInputs(self):
    """A malformed layout string raises ValueError."""
    with self.assertRaises(ValueError):
      mtf.convert_to_layout_rules("d_ff;heads")
def toy_model(features, mesh):
  """A toy two-layer model implemented with Mesh TensorFlow.

  The input is imported with shape [batch, io], passed through two bias-free
  dense layers (io -> hidden -> io), and compared with itself.

  Args:
    features: the input tensor to import onto the mesh.
    mesh: the mesh to build the model on.

  Returns:
    y: the output of the second dense layer.
    loss: the sum of squared differences between y and the imported input.
  """
  batch_dim = mtf.Dimension('batch', FLAGS.batch_size)
  hidden_dim = mtf.Dimension('hidden', FLAGS.hidden_size)
  io_dim = mtf.Dimension('io', FLAGS.io_size)

  input_shape = mtf.Shape([batch_dim, io_dim])
  x = mtf.import_tf_tensor(mesh, features, input_shape)
  hidden = mtf_layers.dense(x, hidden_dim, name='layer1', use_bias=False)
  y = mtf_layers.dense(hidden, io_dim, name='layer2', use_bias=False)

  reconstruction_error = mtf.square(y - x)
  loss = mtf.reduce_sum(reconstruction_error)
  return y, loss
def testMeshImpl(self):
  """Checks MeshImpl properties and its tensor-to-mesh-axis mapping."""
  mesh_shape = mtf.Shape([mtf.Dimension("batch", 4),
                          mtf.Dimension("model", 8)])
  rules = mtf.LayoutRules([("batch", "batch"),
                           ("d_ff", "model"),
                           ("heads", "model")])
  mesh_impl = mtf.MeshImpl(shape=mesh_shape, layout_rules=rules)

  self.assertEqual(mesh_impl.shape, mesh_shape)
  self.assertEqual(mesh_impl.ndims, len(mesh_shape))
  self.assertEqual(mesh_impl.layout_rules, rules)
  self.assertEqual(mesh_impl.size, mesh_shape.size)
  self.assertTrue(mesh_impl.supports_control_dependencies)

  batch = mtf.Dimension("batch", 128)
  length = mtf.Dimension("length", 500)
  d_ff = mtf.Dimension("d_ff", 2048)
  heads = mtf.Dimension("heads", 8)

  # "batch" maps to mesh axis 0; "d_ff" and "heads" both map to axis 1.
  for tensor_dim, expected_axis in ((batch, 0), (d_ff, 1), (heads, 1)):
    self.assertEqual(
        mesh_impl.tensor_dimension_to_mesh_axis(tensor_dim), expected_axis)

  # "length" has no layout rule, so its entry in the layout is None.
  layout = mesh_impl.tensor_layout(mtf.Shape([batch, length, d_ff]))
  self.assertEqual(layout, mtf.TensorLayout([0, None, 1]))
def dot_product_attention(q, k, v, mask, dropout=0.0,
                          dropout_broadcast_dims=None):
  """Dot-product attention.

  Computes softmax(einsum(q, k) + mask) over the key/value length dimension,
  optionally applies dropout to the weights, and combines them with v.

  Args:
    q: Tensor with shape [...., length_q, depth_k]. Typically leading
      dimensions are [batch, heads].
    k: Tensor with shape [..., length_kv, depth_k]. Leading dimensions must
      match with q.
    v: Tensor with shape [..., length_kv, depth_v] Leading dimensions must
      match with q.
    mask: mask Tensor (see attention_mask()), added to the logits; may be
      None.
    dropout: a float.
    dropout_broadcast_dims: an optional list of mtf.Dimension

  Returns:
    Tensor with shape [..., length_q, depth_v].
  """
  # Everything in q except its last (depth) dimension.
  query_dims = q.shape.dims[:-1]
  length_kv = k.shape.dims[-2]

  logits = mtf.einsum([q, k], mtf.Shape(query_dims + [length_kv]))
  if mask is not None:
    logits += mask
  weights = mtf.softmax(logits, length_kv)

  if dropout != 0.0:
    # mtf.dropout takes the keep probability, hence 1.0 - dropout.
    weights = mtf.dropout(
        weights, 1.0 - dropout,
        noise_shape=weights.shape - dropout_broadcast_dims)

  depth_v = v.shape.dims[-1]
  return mtf.einsum([weights, v], mtf.Shape(query_dims + [depth_v]))
def testLowering(self):
  """Round-trips a scalar through import, lowering and export."""
  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  inputs = tf.constant(0.)
  mtf_inputs = mtf.import_tf_tensor(
      mesh, tf_tensor=inputs, shape=mtf.Shape([]))

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})

  outputs = lowering.export_to_tf_tensor(mtf_inputs)
  inputs_value, outputs_value = self.evaluate([inputs, outputs])
  self.assertEqual(inputs_value, outputs_value)

  # Check that methods run without error.
  lowering.copy_masters_to_slices()
  lowering.copy_slices_to_masters()
def testMultiheadAttention(self, kv_channels, heads):
  """Checks that self-attention output has the same shape as the query."""
  batch, length, channels = 2, 8, 3
  query = tf.random_normal([batch, length, channels])

  graph = mtf.Graph()
  mesh = mtf.Mesh(graph, "my_mesh")
  batch_dim = mtf.Dimension("batch", batch)
  length_dim = mtf.Dimension("length", length)
  channels_dim = mtf.Dimension("channels", channels)
  kv_channels_dim = mtf.Dimension("kv_channels", kv_channels)
  heads_dim = mtf.Dimension("heads", heads)

  query_shape = mtf.Shape([batch_dim, length_dim, channels_dim])
  mtf_query = mtf.import_tf_tensor(mesh, query, shape=query_shape)
  # memory_antecedent=None selects self-attention over the query.
  mtf_outputs = mtf_layers.multihead_attention(
      mtf_query,
      memory_antecedent=None,
      mask=None,
      kv_channels=kv_channels_dim,
      heads=heads_dim)

  # Lower onto a trivial mesh with a single default device.
  mesh_impl = placement_mesh_impl.PlacementMeshImpl(
      shape=[], layout={}, devices=[""])
  lowering = mtf.Lowering(graph, {mesh: mesh_impl})
  actual_outputs = lowering.export_to_tf_tensor(mtf_outputs)

  tf_group = lowering.copy_masters_to_slices()
  init = tf.global_variables_initializer()
  with self.test_session() as sess:
    sess.run(init)
    sess.run(tf_group)
    actual = sess.run(actual_outputs)

  self.assertEqual(actual.shape, query.shape)
def estimator_model_fn(cls, hparams, features, labels, mode, config=None,
                       params=None, decode_hparams=None, use_tpu=False,
                       xla_compile=False):
    """Build the Estimator spec for a Mesh-TensorFlow model.

    Constructs the mtf graph and mesh, instantiates the model, and then
    branches on `mode`: PREDICT delegates to `model.estimator_spec_predict`,
    EVAL lowers loss and logits and delegates to `model.estimator_spec_eval`,
    and TRAIN builds gradients and optimizer updates, lowers them, and returns
    a training spec with restore and checkpoint-saver hooks.

    Args:
      cls: the model class; called as
        `cls(hparams, mode, data_parallelism=..., decode_hparams=...)`.
        NOTE(review): presumably this function is bound as a classmethod;
        the decorator is not visible in this block.
      hparams: hyperparameters object. It is deep-copied before being
        modified. Reads `mesh_shape`, `layout` and `model_dir`.
      features: passed through to the model's predict/model/eval functions.
      labels: passed through to `model.estimator_spec_eval`.
      mode: a `tf.estimator.ModeKeys` value.
      config: optional object providing `data_parallelism`; required when
        `use_tpu` is False.
      params: mapping whose `"context"` entry provides `device_assignment`;
        only read when `use_tpu` is True.
      decode_hparams: optional hyperparameters merged into `hparams` in
        PREDICT mode.
      use_tpu: a boolean selecting the SIMD (TPU) mesh implementation instead
        of the placement mesh implementation.
      xla_compile: ignored.

    Returns:
      The value of `model.estimator_spec_predict` (PREDICT), the value of
      `model.estimator_spec_eval` (EVAL), or a `TPUEstimatorSpec` /
      `tf.estimator.EstimatorSpec` for training.

    Raises:
      ValueError: if `use_tpu` is False and no `config` was supplied, so no
        data-parallelism information is available to place the mesh.
    """
    del xla_compile
    hparams = copy.deepcopy(hparams)
    hparams.use_tpu = use_tpu

    # merge decode_hparams into hparams if present
    if mode == tf.estimator.ModeKeys.PREDICT and decode_hparams is not None:
        for k, v in six.iteritems(decode_hparams.values()):
            if hasattr(hparams, k) and getattr(hparams, k) != v:
                tf.logging.warning(
                    "Overriding hparams.%s with %s from decode_hparams"
                    % (k, v))
            setattr(hparams, k, v)

    # Instantiate model
    data_parallelism = None
    if not use_tpu and config:
        data_parallelism = config.data_parallelism
    model = cls(
        hparams,
        mode,
        data_parallelism=data_parallelism,
        decode_hparams=decode_hparams)

    global_step = tf.train.get_global_step()

    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, "my_mesh")
    mesh_shape = mtf.convert_to_shape(hparams.mesh_shape)
    layout_rules = mtf.convert_to_layout_rules(hparams.layout)
    if use_tpu:
        mesh_devices = [""] * mesh_shape.size
        mesh_impl = simd_mesh_impl.SimdMeshImpl(
            mesh_shape, layout_rules, mesh_devices,
            params["context"].device_assignment)
    else:
        if data_parallelism is None:
            # Previously this surfaced as an AttributeError on None below.
            raise ValueError(
                "estimator_model_fn requires a config providing "
                "data_parallelism when use_tpu is False.")
        if len(data_parallelism.ps_devices) == 1:
            # A single device hosts every mesh slice; let TF place the ops.
            mesh_devices = [""] * mesh_shape.size
        else:
            assert len(data_parallelism.ps_devices) == mesh_shape.size
            mesh_devices = data_parallelism.ps_devices
        mesh_impl = placement_mesh_impl.PlacementMeshImpl(
            mesh_shape, layout_rules, mesh_devices)

    # PREDICT mode
    if mode == tf.estimator.ModeKeys.PREDICT:
        return model.estimator_spec_predict(features, mesh, mesh_impl, use_tpu)

    logits, loss = model.mtf_model_fn(features, mesh)
    if use_tpu and logits is not None:
        logits = mtf.anonymize(logits)

    # TRAIN mode
    if mode == tf.estimator.ModeKeys.TRAIN:
        var_grads = mtf.gradients(
            [loss], [v.outputs[0] for v in graph.trainable_variables])
        lr = learning_rate.learning_rate_schedule(hparams)
        mtf_lr = mtf.import_tf_tensor(
            mesh, tf.convert_to_tensor(lr, dtype=tf.float32), mtf.Shape([]))
        optimizer = mtf_optimize.make_optimizer(hparams, mtf_lr)
        update_ops = []
        for grad, var in zip(var_grads, graph.trainable_variables):
            update_ops.extend(optimizer.apply_grad(grad, var))

    lowering = mtf.Lowering(graph, {mesh: mesh_impl})

    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_loss = tf.to_float(tf_loss)

    # Logits are exported only in the EVAL branch below. The earlier
    # unconditional non-TRAIN export was always overwritten there, so it only
    # lowered the logits a second time.

    if mode == tf.estimator.ModeKeys.TRAIN:
        tf_update_ops = [lowering.lowered_operation(op) for op in update_ops]
        tf_update_ops.append(tf.assign_add(global_step, 1))
        train_op = tf.group(tf_update_ops)

    with mtf_utils.outside_all_rewrites():
        # Copy master variables to slices. Must be called first.
        restore_hook = mtf.MtfRestoreHook(lowering)
        saver = tf.train.Saver(
            tf.global_variables(),
            sharded=True,
            max_to_keep=10,
            keep_checkpoint_every_n_hours=2,
            defer_build=False,
            save_relative_paths=True)
        tf.add_to_collection(tf.GraphKeys.SAVERS, saver)
        saver_listener = mtf.MtfCheckpointSaverListener(lowering)
        saver_hook = tf.train.CheckpointSaverHook(
            hparams.model_dir,
            save_steps=1000,
            saver=saver,
            listeners=[saver_listener])

    # EVAL mode
    if mode == tf.estimator.ModeKeys.EVAL:
        tf_logits = lowering.export_to_tf_tensor(logits)
        return model.estimator_spec_eval(
            features, tf_logits, labels, tf_loss, restore_hook, use_tpu)

    # Only TRAIN reaches this point; PREDICT and EVAL returned above.
    if use_tpu:
        _remove_summaries()
        return tpu_estimator.TPUEstimatorSpec(
            mode=tf.estimator.ModeKeys.TRAIN,
            loss=tf_loss,
            train_op=train_op,
            training_hooks=[restore_hook, saver_hook])
    return tf.estimator.EstimatorSpec(
        tf.estimator.ModeKeys.TRAIN,
        loss=tf_loss,
        train_op=train_op,
        training_chief_hooks=[restore_hook, saver_hook])