def model(params, inputs, labels):
    # MTF mesh
    assert len(inputs.shape) == 2
    graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes(
            inputs, labels, params.num_nodes, params.num_gpus,
            params.batch_size)
    embed_mesh, lstm0_mesh, lstm1_mesh, proj_mesh = meshes
    batch_dim_name, n_dim_name, k_dim_name = 'axis0', 'axis1', 'axis2'

    # RNN weights
    num_units = params.num_units
    w_shape = utils.ConvertToShape(
            [(k_dim_name, 2*num_units), (n_dim_name, 4*num_units)])
    rnn_w0 = mtf.get_variable(lstm0_mesh, 'rnn_w0', w_shape)
    rnn_w1 = mtf.get_variable(lstm1_mesh, 'rnn_w1', w_shape)

    # RNN initial states
    h_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
                         mtf.Dimension(k_dim_name, num_units)])
    c_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
                         mtf.Dimension(n_dim_name, num_units)])
    states0 = [mtf.zeros(lstm0_mesh, h_shape), mtf.zeros(lstm0_mesh, c_shape)]
    states1 = [mtf.zeros(lstm1_mesh, h_shape), mtf.zeros(lstm1_mesh, c_shape)]

    # Model - embedding
    vocab_dim = mtf.Dimension(k_dim_name, params.vocab_size)
    embed_dim = mtf.Dimension(n_dim_name, params.num_units)
    assert mtf_inputs.mesh == embed_mesh
    embedding = mtf.layers.embedding(mtf_inputs, vocab_dim, embed_dim,
                                     tf.float32)
    assert embedding.shape[-1].name == n_dim_name
    shape = embedding.shape.rename_dimension(n_dim_name, k_dim_name)
    embedding = mesh_trans.ReplaceMeshWithIndependentAxes(
            embedding, lstm0_mesh, shape.dimension_names)

    # Model - RNN
    [y] = RNNOperation(embedding, rnn_w0, rnn_w1, num_units,
                       states=states0 + states1).outputs
    assert y.mesh == lstm1_mesh
    assert y.shape[-1].name == k_dim_name
    assert mesh_to_impl[proj_mesh].shape[-1] == mtf.Dimension(k_dim_name, 1)
    rand_dim_name = utils.RandName()
    y = mt.rename_dimension(y, k_dim_name, rand_dim_name)
    shape = y.shape.rename_dimension(rand_dim_name, k_dim_name)
    y = mesh_trans.ReplaceMeshWithIndependentAxes(
            y, proj_mesh, shape.dimension_names)

    # Model - Dense + loss
    assert y.shape[-1].name == k_dim_name
    vocab_dim = mtf.Dimension(n_dim_name, params.vocab_size)
    y = mtf.layers.dense(y, vocab_dim, reduced_dims=y.shape[-1:],
                         use_bias=False)
    assert mtf_labels.mesh == proj_mesh
    mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            y, mtf_labels, vocab_dim)
    mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    model.soft_placement = True
    return graph, mesh_to_impl, mtf_loss
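# A minimal usage sketch for model() above -- not from the original source.
# The fields on `params` (num_nodes, num_gpus, batch_size, num_units,
# vocab_size) are inferred from the accesses inside model(); the concrete
# values, the sequence length, and the placeholder dtypes are assumptions
# for illustration.
def _model_usage_sketch():
    from types import SimpleNamespace
    params = SimpleNamespace(num_nodes=1, num_gpus=8, batch_size=32,
                             num_units=1024, vocab_size=32768)
    seq_len = 64  # assumed
    inputs = tf.placeholder(tf.int32, [params.batch_size, seq_len])
    labels = tf.placeholder(tf.int32, [params.batch_size, seq_len])
    # Builds the mtf graph; lowering and execution would follow via the
    # usual mtf.Lowering machinery.
    graph, mesh_to_impl, mtf_loss = model(params, inputs, labels)
    return graph, mesh_to_impl, mtf_loss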
def Transpose1(in_tsr):
    # Replace mesh0 with an identically shaped mesh1, moving 'axis0' and
    # 'axis1' from the first two tensor dimensions to the last two.
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                    mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([('axis0', shape[0]), ('axis1', shape[1]),
                          *shape[2:]])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
            mtf_in_tsr, mesh1, [RandName(), RandName(), 'axis0', 'axis1'])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
def Contract2(in_tsr):
    # Replace a 2x4 mesh with a 4x2 mesh, distributing the middle two tensor
    # dimensions along 'axis0' and 'axis1'.
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2, 4]),
                    mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape)
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
            mtf_in_tsr, mesh1, [RandName(), 'axis0', 'axis1', RandName()])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
def MoreDevices(in_tsr):
    # Replace a 2-device mesh with an 8-device mesh, moving 'axis0' from the
    # last tensor dimension to the second.
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2]),
                    mesh1: GetMeshImpl([8])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:-1] + [('axis0', shape[-1])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
            mtf_in_tsr, mesh1, [RandName(), 'axis0', RandName(), RandName()])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
def DependentAxes(in_tsr):
    # Negative test: the requested new dim names shift 'axis0'/'axis1' by one
    # position, so they are not independent of the input's existing
    # distribution and the replacement should be rejected.
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                    mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([('axis0', shape[0]), ('axis1', shape[1]),
                          *shape[2:]])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
            mtf_in_tsr, mesh1, [RandName(), 'axis0', 'axis1', RandName()])
    try:
        Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
        assert False  # This run should fail and throw a ValueError
    except ValueError:
        return
def WrongShape(in_tsr):
    # Negative test: the input shape names 'axis0' for two different tensor
    # dimensions, which is invalid and should raise a ValueError.
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                    mesh1: GetMeshImpl([8])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:-2] + [('axis0', shape[2]),
                                       ('axis0', shape[3])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
            mtf_in_tsr, mesh1, [RandName(), 'axis0', RandName(), RandName()])
    try:
        Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
        assert False  # This test should fail with a ValueError
    except ValueError:
        return
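# A minimal sketch of a driver for the five ReplaceMeshWithIndependentAxes
# tests above -- not from the original source. The 4-D input and the dimension
# size of 16 are assumptions; the sizes just need to be divisible by every
# mesh axis used in the tests (2, 4, and 8).
def _run_replace_mesh_tests_sketch():
    in_tsr = tf.random_uniform([16, 16, 16, 16])
    for test in (Transpose1, Contract2, MoreDevices, DependentAxes,
                 WrongShape):
        test(in_tsr)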
def Transformer(src, tgt, params, src_vocab_size, tgt_vocab_size, strategy,
                num_nodes, num_gpus):
    graph, meshes, mesh_to_impl, mtf_src, mtf_tgt = CreateMeshes(
            strategy, src, tgt, num_nodes, num_gpus, params)
    src_vocab_size = utils.RoundUp(src_vocab_size, num_gpus)
    tgt_vocab_size = utils.RoundUp(tgt_vocab_size, num_gpus)

    # mtf dimensions
    if strategy == 0:
        src_vocab_dim = mtf.Dimension(RandName(), src_vocab_size)
        tgt_vocab_dim = mtf.Dimension(RandName(), tgt_vocab_size)
        model_dim = mtf.Dimension(RandName(), params.d_model)
        d_k_dim = mtf.Dimension(RandName(), params.d_k)
        heads_dim = mtf.Dimension(RandName(), params.heads)
        ff_dim = mtf.Dimension(RandName(), params.d_ff)
    elif strategy == 1:
        src_vocab_dim = mtf.Dimension('axis0', src_vocab_size)
        tgt_vocab_dim = mtf.Dimension('axis0', tgt_vocab_size)
        model_dim = mtf.Dimension(RandName(), params.d_model)
        d_k_dim = mtf.Dimension(RandName(), params.d_k)
        heads_dim = mtf.Dimension('axis1', params.heads)
        ff_dim = mtf.Dimension('axis1', params.d_ff)
    elif strategy == 2:
        src_vocab_dim = mtf.Dimension('axis1', src_vocab_size)
        tgt_vocab_dim = mtf.Dimension('axis1', tgt_vocab_size)
        model_dim = mtf.Dimension(RandName(), params.d_model)
        d_k_dim = mtf.Dimension(RandName(), params.d_k)
        heads_dim = mtf.Dimension('axis1', params.heads)
        ff_dim = mtf.Dimension('axis1', params.d_ff)
    else:
        assert False

    seq_len_dim = mtf_src.shape[-1]
    assert mtf_src.shape[-1] == mtf_tgt.shape[-1]

    if strategy == 1:
        check_distribution(mtf_src, meshes[1], {})
        check_distribution(mtf_tgt, meshes[1], {})
    else:
        check_distribution(mtf_src, meshes[0], {0: 'axis0'})
        check_distribution(mtf_tgt, meshes[0], {0: 'axis0'})

    def encoder_decoder(inp, enc_out, vocab_dim, name):
        with tf.variable_scope(name):
            # Embedding
            embed = mtf.layers.embedding(inp, vocab_dim, model_dim,
                                         tf.float32, name=f'{name}_embedding')
            if strategy == 1:
                check_distribution(embed, meshes[1], {})
                shape = embed.shape
                shape = shape.rename_dimension(shape[0].name, 'axis0')
                embed = mesh_trans.ReplaceMeshWithIndependentAxes(
                        embed, meshes[0], dim_names=shape.dimension_names)
            check_distribution(embed, meshes[0], {0: 'axis0'})

            # Positional encoding
            x = positional_encoding(embed)
            check_distribution(x, meshes[0], {0: 'axis0'})

            # Encoder/decoder layers
            for i in range(params.nx):
                # Multihead attention
                y = mtf.layers.multihead_attention(
                        x, None, None, d_k_dim, heads_dim, dropout=0.5,
                        name=f'{name}_att_{i}')
                x = add_norm(x, y, name=f'{name}_att_{i}_norm')
                check_distribution(x, meshes[0], {0: 'axis0'})

                if enc_out is not None:
                    y = mtf.layers.multihead_attention(
                            x, enc_out, None, d_k_dim, heads_dim, dropout=0.5,
                            name=f'{name}_att2_{i}')
                    x = add_norm(x, y, name=f'{name}_att2_{i}_norm')
                    check_distribution(x, meshes[0], {0: 'axis0'})

                # Feed forward
                y = mtf.layers.dense_relu_dense(x, ff_dim, dropout=0.5,
                                                name=f'{name}_ff_{i}')
                x = add_norm(x, y, name=f'{name}_ff_{i}_norm')
                check_distribution(x, meshes[0], {0: 'axis0'})

        return x

    # Encoder/Decoder
    enc_out = encoder_decoder(mtf_src, None, src_vocab_dim, 'encoder')
    check_distribution(enc_out, meshes[0], {0: 'axis0'})
    enc_out = mt.rename_dimension(enc_out, enc_out.shape[1].name, RandName())
    dec_out = encoder_decoder(mtf_tgt, enc_out, tgt_vocab_dim, 'decoder')

    # Loss function
    with tf.variable_scope('loss'):
        check_distribution(dec_out, meshes[0], {0: 'axis0'})
        if strategy == 1:
            shape = dec_out.shape.rename_dimension('axis0',
                                                   mtf_tgt.shape[0].name)
            dec_out = mesh_trans.ReplaceMeshWithIndependentAxes(
                    dec_out, meshes[1], dim_names=shape.dimension_names)
            check_distribution(dec_out, meshes[1], {})

        out = mtf.layers.dense(dec_out, tgt_vocab_dim, use_bias=False,
                               reduced_dims=dec_out.shape[-1:],
                               name='final_projection')
        assert out.mesh == mtf_tgt.mesh
        assert out.shape.dims == mtf_tgt.shape.dims + [tgt_vocab_dim]
        out = mtf.layers.softmax_cross_entropy_with_logits(
                out, mtf_tgt, tgt_vocab_dim)
        loss = mtf.reduce_mean(out)

    return graph, mesh_to_impl, loss
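# A minimal usage sketch for Transformer() above -- not from the original
# source. The hyperparameter names on `params` (d_model, d_k, heads, d_ff, nx)
# mirror the accesses inside the function; the concrete values, the extra
# batch/sequence fields, and the vocabulary sizes are assumptions.
def _transformer_usage_sketch():
    from types import SimpleNamespace
    params = SimpleNamespace(d_model=512, d_k=64, heads=8, d_ff=2048, nx=6,
                             batch_size=32, max_seq_len=64)
    src = tf.placeholder(tf.int32, [params.batch_size, params.max_seq_len])
    tgt = tf.placeholder(tf.int32, [params.batch_size, params.max_seq_len])
    graph, mesh_to_impl, loss = Transformer(
            src, tgt, params, src_vocab_size=32768, tgt_vocab_size=32768,
            strategy=1, num_nodes=1, num_gpus=8)
    return graph, mesh_to_impl, loss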