def CreateMesh(mesh_shape): nonlocal mesh_id num_nodes = ((utils.Prod(mesh_shape) + gpus_per_node - 1) // gpus_per_node) mesh = mtf.Mesh(graph, f'mesh{mesh_id}') meshes.append(mesh) mesh_id += 1 mesh_to_impl[mesh] = utils.GetMeshImpl(mesh_shape, gpus_per_node=gpus_per_node) return mesh
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Build three meshes and import the inputs/labels onto the third one.

    ``mesh0`` and ``mesh1`` each get one half of the device list;
    ``mesh2`` spans all devices, ordered by transposing the two halves and
    flattening the result.

    Args:
        inputs: Rank-2 tensor handed to ``mtf.import_tf_tensor``.
        labels: Tensor with the same shape as ``inputs``.
        num_nodes: Node count; must divide ``num_gpus``.
        num_gpus: Total device count; must be even.
        batch_size: Size of the leading ``'axis0'`` dimension.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.

    Raises:
        AssertionError: If any of the device-count or shape checks fails.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    assert num_gpus % num_nodes == 0
    assert num_gpus % 2 == 0

    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)
    half = num_gpus // 2

    def _new_mesh(index):
        # Create and record the mesh before its implementation is looked up,
        # so the order of calls matches a straight-line version.
        created = mtf.Mesh(graph, f'mesh{index}')
        meshes.append(created)
        return created

    lower_mesh = _new_mesh(0)
    mesh_to_impl[lower_mesh] = utils.GetMeshImpl(
        [half], devices=devices[:half], gpus_per_node=gpus_per_node)

    upper_mesh = _new_mesh(1)
    mesh_to_impl[upper_mesh] = utils.GetMeshImpl(
        [half], devices=devices[half:], gpus_per_node=gpus_per_node)

    full_mesh = _new_mesh(2)
    # Transpose-then-flatten of the two halves; presumably this interleaves
    # them (confirm against utils.TransposeLists / utils.FlattenList).
    mixed_devices = utils.FlattenList(
        utils.TransposeLists([devices[:half], devices[half:]]))
    mesh_to_impl[full_mesh] = utils.GetMeshImpl(
        [num_gpus], devices=mixed_devices, gpus_per_node=gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    second_dim = inputs.shape.as_list()[1]
    input_shape = utils.ConvertToShape([('axis0', batch_size), second_dim])
    mtf_inputs = mtf.import_tf_tensor(full_mesh, inputs, input_shape)

    # Labels get a freshly named leading dimension instead of 'axis0'.
    label_shape = input_shape.rename_dimension('axis0', utils.RandName())
    mtf_labels = mtf.import_tf_tensor(full_mesh, labels, label_shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Build a single mesh over all devices and import inputs/labels onto it.

    Args:
        inputs: Rank-2 tensor handed to ``mtf.import_tf_tensor``.
        labels: Tensor with the same shape as ``inputs``.
        num_nodes: Node count, used to derive the GPUs-per-node value.
        num_gpus: Total device count; also the size of the one mesh axis.
        batch_size: Size of the leading ``'axis0'`` dimension.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``
        where ``meshes`` holds exactly one mesh named ``'mesh0'``.

    Raises:
        AssertionError: If ``inputs`` is not rank 2 or the shapes differ.
    """
    graph = mtf.Graph()
    only_mesh = mtf.Mesh(graph, 'mesh0')
    impl = utils.GetMeshImpl([num_gpus],
                             gpus_per_node=num_gpus // num_nodes)
    meshes = [only_mesh]
    mesh_to_impl = {only_mesh: impl}

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # Inputs and labels share one shape: ('axis0', batch_size) followed by
    # the second dimension of the incoming tensor.
    second_dim = inputs.shape.as_list()[1]
    shared_shape = utils.ConvertToShape([('axis0', batch_size), second_dim])

    mtf_inputs = mtf.import_tf_tensor(only_mesh, inputs, shared_shape)
    mtf_labels = mtf.import_tf_tensor(only_mesh, labels, shared_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Build four meshes from a fixed per-GPU-count layout table.

    The meshes are named ``mesh0`` .. ``mesh3``. ``mesh0`` uses the full
    device list, ``mesh1`` and ``mesh2`` use the first and second half
    respectively, and ``mesh3`` uses the full list with the two halves
    swapped. Inputs are imported onto the first mesh and labels onto the
    last one.

    Args:
        inputs: Rank-2 tensor handed to ``mtf.import_tf_tensor``.
        labels: Tensor with the same shape as ``inputs``.
        num_nodes: Node count; must be 1 or even.
        num_gpus: Total device count; must be one of 4, 8, 16, 32, 64.
        batch_size: Size of the leading ``'axis0'`` dimension.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.

    Raises:
        AssertionError: If a shape or node-count check fails, or if
            ``num_gpus`` has no entry in the layout table.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # Mesh_shape: batch_dim, n_dim, k_dim
    shapes_by_gpu_count = {
        4: [[1, 1, 4], [2, 1, 1], [2, 1, 1], [1, 4, 1]],
        8: [[1, 1, 8], [4, 1, 1], [4, 1, 1], [1, 8, 1]],
        16: [[1, 1, 16], [2, 2, 2], [2, 2, 2], [1, 16, 1]],
        32: [[1, 1, 32], [4, 2, 2], [4, 2, 2], [1, 32, 1]],
        64: [[1, 1, 64], [8, 2, 2], [8, 2, 2], [1, 64, 1]],
    }
    mesh_shapes = shapes_by_gpu_count.get(num_gpus)
    if mesh_shapes is None:
        # Raised explicitly rather than via ``assert False`` so the guard
        # survives ``python -O``; the exception type is unchanged.
        raise AssertionError(
            'Unsupported num_gpus: %r (expected one of %s)'
            % (num_gpus, sorted(shapes_by_gpu_count)))

    # The two middle meshes must share a shape and each cover half the GPUs.
    assert mesh_shapes[1] == mesh_shapes[2]
    assert (utils.Prod(mesh_shapes[1]) == utils.Prod(mesh_shapes[2])
            == num_gpus // 2)
    assert (num_nodes == 1) or (num_nodes % 2 == 0)

    half = num_gpus // 2
    half_devices0 = devices[:half]
    half_devices1 = devices[half:]
    mesh_devices = [devices, half_devices0, half_devices1,
                    half_devices1 + half_devices0]

    for i, (mesh_shape, ds) in enumerate(zip(mesh_shapes, mesh_devices)):
        mesh = mtf.Mesh(graph, 'mesh' + str(i))
        meshes.append(mesh)
        mesh_to_impl[mesh] = utils.GetMeshImpl(mesh_shape, devices=ds,
                                               gpus_per_node=gpus_per_node)

    mtf_shape = utils.ConvertToShape(
        [('axis0', batch_size)] + inputs.shape.as_list()[1:])
    mtf_inputs = mtf.import_tf_tensor(meshes[0], inputs, mtf_shape)
    mtf_labels = mtf.import_tf_tensor(meshes[-1], labels, mtf_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def GetMeshImpl(dev_cnts, devices=None):
    """Wrap ``utils.GetMeshImpl``, converting a device list to device strings.

    Args:
        dev_cnts: Forwarded unchanged as the first argument.
        devices: If this is a ``list``, each element ``d`` is replaced by
            ``utils.GetDeviceStr(0, d)`` before forwarding; any other value
            (including the default ``None``) is forwarded as-is.

    Returns:
        Whatever ``utils.GetMeshImpl`` returns.
    """
    # NOTE(review): the constant 0 is presumably a node/task index --
    # confirm against utils.GetDeviceStr.
    resolved = (
        [utils.GetDeviceStr(0, dev) for dev in devices]
        if isinstance(devices, list)
        else devices
    )
    return utils.GetMeshImpl(dev_cnts, resolved)