def model(params, inputs, labels):
    """Build the Mesh-TensorFlow graph: embedding -> RNN -> dense -> loss.

    The four meshes returned by ``CreateMeshes`` are used, in order, for the
    embedding, the first RNN weight/state set, the second RNN weight/state
    set, and the final projection.  Between stages the activation is moved
    to the next mesh with ``mesh_trans.ReplaceMeshWithIndependentAxes``.

    Args:
        params: Object exposing ``num_nodes``, ``num_gpus``, ``batch_size``,
            ``num_units`` and ``vocab_size``.
        inputs: Rank-2 input tensor (asserted); presumably token ids, since
            it feeds the embedding lookup.
        labels: Label tensor, forwarded to ``CreateMeshes``.

    Returns:
        Tuple ``(graph, mesh_to_impl, mtf_loss)`` where ``mtf_loss`` is the
        mean softmax cross-entropy.

    Side effects:
        Sets ``model.soft_placement = True`` on this function object.
    """
    # MTF meshes and imported tensors.
    assert len(inputs.shape) == 2
    graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes(
        inputs, labels, params.num_nodes, params.num_gpus,
        params.batch_size)
    embed_mesh, lstm0_mesh, lstm1_mesh, proj_mesh = meshes
    batch_axis, n_axis, k_axis = 'axis0', 'axis1', 'axis2'
    hidden_size = params.num_units

    # RNN weights: one variable per RNN mesh, both with the same shape.
    weight_shape = utils.ConvertToShape(
        [(k_axis, 2 * hidden_size), (n_axis, 4 * hidden_size)])
    weights0 = mtf.get_variable(lstm0_mesh, 'rnn_w0', weight_shape)
    weights1 = mtf.get_variable(lstm1_mesh, 'rnn_w1', weight_shape)

    # RNN initial states: for each RNN mesh a zero tensor whose feature
    # dimension is named k_axis followed by one named n_axis.
    def zero_state(mesh, feature_axis):
        state_shape = mtf.Shape([
            mtf.Dimension(batch_axis, params.batch_size),
            mtf.Dimension(feature_axis, hidden_size),
        ])
        return mtf.zeros(mesh, state_shape)

    initial_states = [
        zero_state(mesh, axis)
        for mesh in (lstm0_mesh, lstm1_mesh)
        for axis in (k_axis, n_axis)
    ]

    # Model - embedding.
    vocab_dim = mtf.Dimension(k_axis, params.vocab_size)
    embed_dim = mtf.Dimension(n_axis, params.num_units)
    assert mtf_inputs.mesh == embed_mesh
    embedded = mtf.layers.embedding(
        mtf_inputs, vocab_dim, embed_dim, tf.float32)
    assert embedded.shape[-1].name == n_axis
    # The last dimension is called k_axis on the first RNN mesh.
    target_names = embedded.shape.rename_dimension(
        n_axis, k_axis).dimension_names
    embedded = mesh_trans.ReplaceMeshWithIndependentAxes(
        embedded, lstm0_mesh, target_names)

    # Model - RNN.
    rnn_op = RNNOperation(embedded, weights0, weights1, hidden_size,
                          states=initial_states)
    (hidden,) = rnn_op.outputs
    assert hidden.mesh == lstm1_mesh
    assert hidden.shape[-1].name == k_axis
    assert mesh_to_impl[proj_mesh].shape[-1] == mtf.Dimension(k_axis, 1)
    # Rename the last dimension to a fresh name, then ask for it back under
    # k_axis while switching to the projection mesh.
    scratch_axis = utils.RandName()
    hidden = mt.rename_dimension(hidden, k_axis, scratch_axis)
    target_names = hidden.shape.rename_dimension(
        scratch_axis, k_axis).dimension_names
    hidden = mesh_trans.ReplaceMeshWithIndependentAxes(
        hidden, proj_mesh, target_names)

    # Model - dense + loss.
    assert hidden.shape[-1].name == k_axis
    vocab_dim = mtf.Dimension(n_axis, params.vocab_size)
    logits = mtf.layers.dense(hidden, vocab_dim,
                              reduced_dims=hidden.shape[-1:],
                              use_bias=False)
    assert mtf_labels.mesh == proj_mesh
    cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf_labels, vocab_dim)
    mtf_loss = mtf.reduce_mean(cross_entropy)

    model.soft_placement = True
    return graph, mesh_to_impl, mtf_loss
def model(params, inputs, labels):
    """Build the Mesh-TensorFlow graph: embedding -> RNN -> dense -> loss.

    ``meshes[0]`` and ``meshes[1]`` hold the two RNN weight/state sets;
    ``meshes[2]`` holds the embedding (asserted) and receives the RNN output
    again before the final dense layer.

    Args:
        params: Object exposing ``num_nodes``, ``num_gpus``, ``batch_size``,
            ``num_units`` and ``vocab_size``.
        inputs: Rank-2 input tensor (asserted); presumably token ids, since
            it feeds the embedding lookup.
        labels: Label tensor, forwarded to ``CreateMeshes``.

    Returns:
        Tuple ``(graph, mesh_to_impl, mtf_loss)`` where ``mtf_loss`` is the
        mean softmax cross-entropy.

    Side effects:
        Sets ``model.soft_placement = True`` on this function object.
    """
    # Mtf meshes and imported tensors.
    assert len(inputs.shape) == 2
    graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes(
        inputs, labels, params.num_nodes, params.num_gpus,
        params.batch_size)

    # Embedding dimensions get freshly generated names; the embedding
    # dimension name doubles as the RNN "k" dimension name.
    vocab_dim = mtf.Dimension(utils.RandName(), params.vocab_size)
    embed_dim = mtf.Dimension(utils.RandName(), params.num_units)
    batch_axis = mtf_inputs.shape[0].name
    k_axis = embed_dim.name
    n_axis = utils.RandName()
    hidden_size = params.num_units

    # RNN weights.
    weight_shape = utils.ConvertToShape(
        [(k_axis, 2 * hidden_size), (n_axis, 4 * hidden_size)])
    rnn_mesh0 = meshes[0]
    weights0 = mtf.get_variable(rnn_mesh0, 'rnn_w0', weight_shape)
    rnn_mesh1 = meshes[1]
    weights1 = mtf.get_variable(rnn_mesh1, 'rnn_w1', weight_shape)

    # RNN initial states: for each RNN mesh a zero tensor whose feature
    # dimension is named k_axis followed by one named n_axis.
    def zero_state(mesh, feature_axis):
        state_shape = mtf.Shape([
            mtf.Dimension(batch_axis, params.batch_size),
            mtf.Dimension(feature_axis, hidden_size),
        ])
        return mtf.zeros(mesh, state_shape)

    initial_states = [
        zero_state(mesh, axis)
        for mesh in (rnn_mesh0, rnn_mesh1)
        for axis in (k_axis, n_axis)
    ]

    # Model.
    embedded = mtf.layers.embedding(
        mtf_inputs, vocab_dim, embed_dim, tf.float32)
    assert embedded.mesh == meshes[2]
    embedded = ReplaceRNNMesh(embedded, meshes[0]).outputs[0]

    rnn_op = RNNOperation(embedded, weights0, weights1, hidden_size,
                          states=initial_states)
    (hidden,) = rnn_op.outputs
    assert hidden.mesh == meshes[1]
    assert hidden.shape[0].name == 'axis0'
    # Give the leading dimension the labels' leading-dimension name before
    # moving the tensor back to meshes[2].
    hidden = mt.rename_dimension(hidden, 'axis0', mtf_labels.shape[0].name)
    hidden = mesh_trans.ReplaceMeshWithSimpleReplication(hidden, meshes[2])

    vocab_dim = mtf.Dimension('axis0', params.vocab_size)
    logits = mtf.layers.dense(hidden, vocab_dim,
                              reduced_dims=hidden.shape[-1:],
                              use_bias=False)
    assert logits.mesh == mtf_labels.mesh
    cross_entropy = mtf.layers.softmax_cross_entropy_with_logits(
        logits, mtf_labels, vocab_dim)
    mtf_loss = mtf.reduce_mean(cross_entropy)

    model.soft_placement = True
    return graph, mesh_to_impl, mtf_loss
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Create a single 1-D mesh over all GPUs and import inputs/labels on it.

    Args:
        inputs: Rank-2 tensor (asserted) to import.
        labels: Tensor with the same static shape as ``inputs`` (asserted).
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs; also the size of the mesh.
        batch_size: Size given to the leading dimension, named ``'axis0'``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``
        where ``meshes`` contains the single mesh ``'mesh0'``.
    """
    graph = mtf.Graph()
    only_mesh = mtf.Mesh(graph, 'mesh0')
    meshes = [only_mesh]
    mesh_to_impl = {
        only_mesh: utils.GetMeshImpl(
            [num_gpus], gpus_per_node=num_gpus // num_nodes),
    }

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # The leading dimension is named 'axis0'; the second keeps its static
    # size and is passed unnamed.
    second_dim = inputs.shape.as_list()[1]
    tensor_shape = utils.ConvertToShape([('axis0', batch_size), second_dim])
    mtf_inputs = mtf.import_tf_tensor(only_mesh, inputs, tensor_shape)
    mtf_labels = mtf.import_tf_tensor(only_mesh, labels, tensor_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Create two half-size meshes plus one full mesh and import the data.

    ``mesh0`` uses the first half of the device list, ``mesh1`` the second
    half, and ``mesh2`` all devices in the order produced by
    ``utils.FlattenList(utils.TransposeLists([first_half, second_half]))``.
    Inputs and labels are both imported on ``mesh2``; the labels' leading
    dimension gets a freshly generated name instead of ``'axis0'``.

    Args:
        inputs: Rank-2 tensor (asserted) to import.
        labels: Tensor with the same static shape as ``inputs`` (asserted).
        num_nodes: Number of nodes; must divide ``num_gpus`` (asserted).
        num_gpus: Total number of GPUs; must be even (asserted).
        batch_size: Size given to the leading dimension.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    assert num_gpus % num_nodes == 0
    assert num_gpus % 2 == 0
    gpus_per_node = num_gpus // num_nodes
    half = num_gpus // 2
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    def add_mesh(mesh_shape, mesh_devices):
        # Meshes are named by creation order: mesh0, mesh1, mesh2.
        new_mesh = mtf.Mesh(graph, 'mesh%d' % len(meshes))
        meshes.append(new_mesh)
        mesh_to_impl[new_mesh] = utils.GetMeshImpl(
            mesh_shape, devices=mesh_devices, gpus_per_node=gpus_per_node)
        return new_mesh

    add_mesh([half], devices[:half])
    add_mesh([half], devices[half:])
    full_mesh = add_mesh(
        [num_gpus],
        utils.FlattenList(
            utils.TransposeLists([devices[:half], devices[half:]])))

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    input_shape = utils.ConvertToShape(
        [('axis0', batch_size), inputs.shape.as_list()[1]])
    mtf_inputs = mtf.import_tf_tensor(full_mesh, inputs, input_shape)
    label_shape = input_shape.rename_dimension('axis0', utils.RandName())
    mtf_labels = mtf.import_tf_tensor(full_mesh, labels, label_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Create four 3-D meshes and import inputs on the first, labels on the last.

    Mesh shapes are ordered ``(batch_dim, n_dim, k_dim)``:

    * ``mesh0``: ``[1, 1, num_gpus]`` over all devices,
    * ``mesh1`` / ``mesh2``: identical shapes covering ``num_gpus // 2``
      devices each (first and second half of the device list),
    * ``mesh3``: ``[1, num_gpus, 1]`` over the second half followed by the
      first half.

    Only 4, 8, 16, 32 and 64 GPUs are supported; anything else fails an
    assertion.

    Args:
        inputs: Rank-2 tensor (asserted) to import on ``mesh0``.
        labels: Tensor with the same static shape as ``inputs`` (asserted),
            imported on ``mesh3``.
        num_nodes: Number of nodes; must be 1 or even (asserted).
        num_gpus: Total number of GPUs.
        batch_size: Size given to the leading dimension, named ``'axis0'``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # Shape of each half-size mesh, keyed by total GPU count.
    # Mesh_shape: batch_dim, n_dim, k_dim
    half_mesh_shapes = {
        4: (2, 1, 1),
        8: (4, 1, 1),
        16: (2, 2, 2),
        32: (4, 2, 2),
        64: (8, 2, 2),
    }
    assert num_gpus in half_mesh_shapes
    half_shape = half_mesh_shapes[num_gpus]
    mesh_shapes = [
        [1, 1, num_gpus],
        list(half_shape),
        list(half_shape),
        [1, num_gpus, 1],
    ]
    assert mesh_shapes[1] == mesh_shapes[2]
    assert (utils.Prod(mesh_shapes[1]) == utils.Prod(mesh_shapes[2])
            == num_gpus // 2)
    assert (num_nodes == 1) or (num_nodes % 2 == 0)

    split = num_gpus // 2
    front_devices = devices[:split]
    back_devices = devices[split:]
    device_lists = [
        devices,
        front_devices,
        back_devices,
        back_devices + front_devices,
    ]

    for index, (mesh_shape, mesh_devices) in enumerate(
            zip(mesh_shapes, device_lists)):
        new_mesh = mtf.Mesh(graph, f'mesh{index}')
        meshes.append(new_mesh)
        mesh_to_impl[new_mesh] = utils.GetMeshImpl(
            mesh_shape, devices=mesh_devices, gpus_per_node=gpus_per_node)

    # Leading dimension is named 'axis0'; remaining dims keep static sizes.
    trailing_dims = inputs.shape.as_list()[1:]
    mtf_shape = utils.ConvertToShape([('axis0', batch_size)] + trailing_dims)
    mtf_inputs = mtf.import_tf_tensor(meshes[0], inputs, mtf_shape)
    mtf_labels = mtf.import_tf_tensor(meshes[-1], labels, mtf_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def CreateMeshes(args, img, labels, num_nodes, num_gpus):
    """Create the meshes for ``args.strategy`` and import images and labels.

    Every strategy creates ``mesh0`` as a 1-D mesh over all GPUs and imports
    the images on it with shape ``[('axis0', batch_size), 299, 299, 3]``.
    The strategies differ in the second mesh and in how labels are imported:

    * 0: labels on ``mesh0`` with the batch dimension named ``'axis0'``.
    * 1: ``mesh1`` is ``[4, 1]`` for 4 GPUs, else ``[num_gpus // 2, 2]``;
      labels on ``mesh1`` with an unnamed batch dimension.
    * 2: labels on ``mesh0`` with an unnamed batch dimension.
    * 3: ``mesh1`` is ``[2, num_gpus // 2]``; labels on ``mesh1`` with the
      batch dimension named ``'axis0'``.

    Any other strategy fails an assertion.

    Args:
        args: Object exposing ``strategy`` and ``batch_size``.
        img: Image tensor to import.
        labels: Label tensor to import.
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_img, mtf_labels)``.
    """
    height, width, channels = 299, 299, 3
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    strategy = args.strategy
    batch_size = args.batch_size
    gpus_per_node = num_gpus // num_nodes

    def add_mesh(mesh_shape):
        # Meshes are named by creation order: mesh0, mesh1, ...
        new_mesh = mtf.Mesh(graph, 'mesh%d' % len(meshes))
        meshes.append(new_mesh)
        mesh_to_impl[new_mesh] = utils.GetMeshImpl(
            mesh_shape, gpus_per_node=gpus_per_node)
        return new_mesh

    assert strategy in (0, 1, 2, 3)

    image_mesh = add_mesh([num_gpus])
    if strategy == 1:
        label_mesh = add_mesh([4, 1] if num_gpus == 4
                              else [num_gpus // 2, 2])
    elif strategy == 3:
        label_mesh = add_mesh([2, num_gpus // 2])
    else:
        label_mesh = image_mesh

    # Strategies 1 and 2 leave the labels' batch dimension unnamed.
    if strategy in (1, 2):
        label_dims = [batch_size]
    else:
        label_dims = [('axis0', batch_size)]

    mtf_img = mtf.import_tf_tensor(
        image_mesh, img,
        utils.ConvertToShape(
            [('axis0', batch_size), height, width, channels]))
    mtf_labels = mtf.import_tf_tensor(
        label_mesh, labels, utils.ConvertToShape(label_dims))
    return graph, meshes, mesh_to_impl, mtf_img, mtf_labels
def CreateMeshes(img, labels, num_nodes, num_gpus, args):
    """Create the meshes for ``args.strategy`` and import images and labels.

    Every strategy creates ``mesh0`` as a 1-D mesh over all GPUs and imports
    the images on it with shape ``[('axis0', batch_size), 227, 227, 3]``.
    The strategies differ in the extra meshes and in the label import:

    * 0: labels on ``mesh0`` with the batch dimension named ``'axis0'``.
    * 1: adds ``mesh1`` of shape ``[dim1, dim2]`` and ``mesh2`` of shape
      ``[1, dim2]`` (dims looked up by GPU count; only 4/8/16/32/64 are
      supported); labels on ``mesh2`` with an unnamed batch dimension.
    * 2: labels on ``mesh0`` with an unnamed batch dimension.
    * 3: adds ``mesh1`` of shape ``[2, num_gpus // 2]``; labels on ``mesh1``
      with the batch dimension named ``'axis0'``.

    Any other strategy fails an assertion.

    Args:
        img: Image tensor to import.
        labels: Label tensor to import.
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs.
        args: Object exposing ``strategy`` and ``batch_size``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_img, mtf_labels)``.
    """
    height, width, channels = 227, 227, 3
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    strategy = args.strategy
    batch_size = args.batch_size
    gpus_per_node = num_gpus // num_nodes

    def add_mesh(mesh_shape):
        # Meshes are named by creation order: mesh0, mesh1, ...
        new_mesh = mtf.Mesh(graph, 'mesh%d' % len(meshes))
        meshes.append(new_mesh)
        mesh_to_impl[new_mesh] = utils.GetMeshImpl(
            mesh_shape, gpus_per_node=gpus_per_node)
        return new_mesh

    assert strategy in (0, 1, 2, 3)

    image_mesh = add_mesh([num_gpus])
    label_mesh = image_mesh
    if strategy == 1:
        # 2-D mesh factorization per supported GPU count.
        factorizations = {
            4: (4, 1),
            8: (4, 2),
            16: (8, 2),
            32: (8, 4),
            64: (8, 8),
        }
        assert num_gpus in factorizations
        dim1, dim2 = factorizations[num_gpus]
        assert (dim1 * dim2) == num_gpus
        add_mesh([dim1, dim2])
        label_mesh = add_mesh([1, dim2])
    elif strategy == 3:
        label_mesh = add_mesh([2, num_gpus // 2])

    # Strategies 1 and 2 leave the labels' batch dimension unnamed.
    if strategy in (1, 2):
        label_dims = [batch_size]
    else:
        label_dims = [('axis0', batch_size)]

    mtf_img = mtf.import_tf_tensor(
        image_mesh, img,
        utils.ConvertToShape(
            [('axis0', batch_size), height, width, channels]))
    mtf_labels = mtf.import_tf_tensor(
        label_mesh, labels, utils.ConvertToShape(label_dims))
    return graph, meshes, mesh_to_impl, mtf_img, mtf_labels
def CreateMeshes(strategy, src, tgt, num_nodes, num_gpus, params):
    """Create the meshes for ``strategy`` and import source/target tensors.

    Strategies:

    * 0 (data-parallel): one 1-D mesh over all GPUs; ``src`` and ``tgt`` are
      imported with the batch dimension named ``'axis0'``.
    * 1: ``mesh0`` is a 2-D ``[dim1, dim2]`` mesh chosen by GPU count (and,
      for 16 and 64 GPUs, by ``params.model_size``); ``mesh1`` is a 1-D mesh
      over all GPUs on which ``src`` and ``tgt`` are imported with an
      unnamed batch dimension.
    * 2: one 2-D ``[dim1, dim2]`` mesh chosen by GPU count; ``src`` and
      ``tgt`` are imported with the batch dimension named ``'axis0'``.

    Only 4, 8, 16, 32 and 64 GPUs are supported by strategies 1 and 2.
    Unsupported strategies or GPU counts fail an assertion.

    Args:
        strategy: Integer strategy id (0, 1 or 2).
        src: Source tensor to import.
        tgt: Target tensor to import.
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs.
        params: Object exposing ``batch_size``, ``max_seq_len`` and (for
            strategy 1 with 16 or 64 GPUs) ``model_size``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_src, mtf_tgt)`` with both
        tensors cast to ``tf.int32``.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    gpus_per_node = num_gpus // num_nodes

    def CreateMesh(mesh_shape):
        # Meshes are named by creation order: mesh0, mesh1, ...
        mesh = mtf.Mesh(graph, f'mesh{len(meshes)}')
        meshes.append(mesh)
        mesh_to_impl[mesh] = utils.GetMeshImpl(
            mesh_shape, gpus_per_node=gpus_per_node)
        return mesh

    if strategy == 0:
        # Data-parallel
        mesh = CreateMesh([num_gpus])
        shape = utils.ConvertToShape(
            [('axis0', params.batch_size), params.max_seq_len])

    elif strategy == 1:
        # Opt strategy from the tool
        if num_gpus == 4:
            dim1, dim2 = 4, 1
        elif num_gpus == 8:
            dim1, dim2 = 8, 1
        elif num_gpus == 16:
            if params.model_size == 'small':
                dim1, dim2 = 16, 1
            else:
                dim1, dim2 = 8, 2
        elif num_gpus == 32:
            dim1, dim2 = 16, 2
        elif num_gpus == 64:
            if params.model_size == 'small':
                dim1, dim2 = 32, 2
            else:
                dim1, dim2 = 16, 4
        else:
            assert False
        assert (dim1 * dim2) == num_gpus

        # mesh0 is the 2-D mesh; the tensors themselves are imported on the
        # 1-D mesh1 with an unnamed batch dimension.
        CreateMesh([dim1, dim2])
        mesh = CreateMesh([num_gpus])
        shape = utils.ConvertToShape(
            [params.batch_size, params.max_seq_len])

    elif strategy == 2:
        # Strategy from mesh-tensorflow paper
        if num_gpus == 4:
            dim1, dim2 = 4, 1
        elif num_gpus == 8:
            dim1, dim2 = 2, 4
        elif num_gpus == 16:
            dim1, dim2 = 4, 4
        elif num_gpus == 32:
            dim1, dim2 = 4, 8
        elif num_gpus == 64:
            dim1, dim2 = 8, 8
        else:
            assert False
        assert (dim1 * dim2) == num_gpus

        mesh = CreateMesh([dim1, dim2])
        shape = utils.ConvertToShape(
            [('axis0', params.batch_size), params.max_seq_len])

    else:
        assert False

    # Import the source and the target tensors. Earlier code imported `src`
    # for both under strategies 0 and 2, which silently ignored `tgt`.
    mtf_src = mtf.import_tf_tensor(mesh, src, shape)
    mtf_tgt = mtf.import_tf_tensor(mesh, tgt, shape)

    mtf_src = mtf.cast(mtf_src, tf.int32)
    mtf_tgt = mtf.cast(mtf_tgt, tf.int32)
    return graph, meshes, mesh_to_impl, mtf_src, mtf_tgt