예제 #1
0
    def CreateMesh(mesh_shape):
        nonlocal mesh_id
        num_nodes = ((utils.Prod(mesh_shape) + gpus_per_node - 1) //
                     gpus_per_node)

        mesh = mtf.Mesh(graph, f'mesh{mesh_id}')
        meshes.append(mesh)
        mesh_id += 1
        mesh_to_impl[mesh] = utils.GetMeshImpl(mesh_shape,
                                               gpus_per_node=gpus_per_node)
        return mesh
예제 #2
0
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Build a graph with three meshes and import the inputs and labels.

    'mesh0' and 'mesh1' are each given one half of the device list returned
    by `utils.GetDeviceList`. 'mesh2' spans all devices, ordered by passing
    the two halves through `utils.TransposeLists` and `utils.FlattenList`.
    Both tensors are imported on 'mesh2'; the labels use a shape whose
    'axis0' dimension has been renamed to `utils.RandName()`.

    Args:
        inputs: Rank-2 tensor; its second static dimension is reused in the
            imported shape.
        labels: Tensor with the same shape as `inputs`.
        num_nodes: Number of nodes; must divide `num_gpus`.
        num_gpus: Total number of GPUs; must be even.
        batch_size: Size given to the 'axis0' dimension.

    Returns:
        Tuple `(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)`.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    assert num_gpus % num_nodes == 0
    assert num_gpus % 2 == 0
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)
    half = num_gpus // 2

    # (mesh name, mesh shape, device selector). The selectors are evaluated
    # lazily so each mesh is created before its device list is built, and
    # every mesh implementation receives freshly sliced lists.
    layout = (
        ('mesh0', [half], lambda: devices[:half]),
        ('mesh1', [half], lambda: devices[half:]),
        ('mesh2', [num_gpus],
         lambda: utils.FlattenList(
             utils.TransposeLists([devices[:half], devices[half:]]))),
    )
    for mesh_name, mesh_shape, select_devices in layout:
        new_mesh = mtf.Mesh(graph, mesh_name)
        meshes.append(new_mesh)
        mesh_to_impl[new_mesh] = utils.GetMeshImpl(
            mesh_shape,
            devices=select_devices(),
            gpus_per_node=gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    target_mesh = meshes[2]
    input_shape = utils.ConvertToShape(
        [('axis0', batch_size), inputs.shape.as_list()[1]])
    mtf_inputs = mtf.import_tf_tensor(target_mesh, inputs, input_shape)

    # The labels get a distinct name for their leading dimension.
    label_shape = input_shape.rename_dimension('axis0', utils.RandName())
    mtf_labels = mtf.import_tf_tensor(target_mesh, labels, label_shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
예제 #3
0
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Build a graph with a single mesh and import the inputs and labels.

    The only mesh, 'mesh0', gets the implementation returned by
    `utils.GetMeshImpl` for shape `[num_gpus]` with
    `num_gpus // num_nodes` GPUs per node. Both tensors are imported on it
    with the same shape.

    Args:
        inputs: Rank-2 tensor; its second static dimension is reused in the
            imported shape.
        labels: Tensor with the same shape as `inputs`.
        num_nodes: Number of nodes.
        num_gpus: Total number of GPUs.
        batch_size: Size given to the 'axis0' dimension.

    Returns:
        Tuple `(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)`.
    """
    graph = mtf.Graph()
    mesh = mtf.Mesh(graph, 'mesh0')
    mesh_impl = utils.GetMeshImpl([num_gpus],
                                  gpus_per_node=num_gpus // num_nodes)
    meshes = [mesh]
    mesh_to_impl = {mesh: mesh_impl}

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    feature_dim = inputs.shape.as_list()[1]
    shape = utils.ConvertToShape([('axis0', batch_size), feature_dim])
    mtf_inputs, mtf_labels = [
        mtf.import_tf_tensor(mesh, tensor, shape)
        for tensor in (inputs, labels)
    ]

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
예제 #4
0
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Build a graph with four meshes and import the inputs and labels.

    Mesh shapes are ordered (batch_dim, n_dim, k_dim):
      * 'mesh0' uses all devices with shape `[1, 1, num_gpus]`.
      * 'mesh1' and 'mesh2' use the first and second half of the devices
        respectively, with the same per-`num_gpus` shape.
      * 'mesh3' uses the second half followed by the first half, with shape
        `[1, num_gpus, 1]`.

    `inputs` is imported on 'mesh0' and `labels` on 'mesh3'.

    Args:
        inputs: Rank-2 tensor; its trailing static dimensions are reused in
            the imported shape.
        labels: Tensor with the same shape as `inputs`.
        num_nodes: Number of nodes; must be 1 or even.
        num_gpus: Total number of GPUs; one of 4, 8, 16, 32 or 64.
        batch_size: Size given to the 'axis0' dimension.

    Returns:
        Tuple `(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)`.

    Raises:
        AssertionError: If `num_gpus` is not a supported value, or if one of
            the sanity checks on the arguments fails.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # Shape (batch_dim, n_dim, k_dim) shared by the two half-size meshes,
    # keyed by the total GPU count.
    half_mesh_shapes = {
        4: [2, 1, 1],
        8: [4, 1, 1],
        16: [2, 2, 2],
        32: [4, 2, 2],
        64: [8, 2, 2],
    }
    mesh_shapes_by_gpus = {
        n: [[1, 1, n], list(half), list(half), [1, n, 1]]
        for n, half in half_mesh_shapes.items()
    }
    if num_gpus not in mesh_shapes_by_gpus:
        # Raised explicitly (rather than `assert False`) so the check is not
        # stripped under `python -O`; the exception type is unchanged.
        raise AssertionError(
            f'Unsupported num_gpus: {num_gpus}; '
            f'expected one of {sorted(mesh_shapes_by_gpus)}')
    mesh_shapes = mesh_shapes_by_gpus[num_gpus]

    assert mesh_shapes[1] == mesh_shapes[2]
    assert (utils.Prod(mesh_shapes[1]) == utils.Prod(mesh_shapes[2]) ==
            num_gpus // 2)
    assert (num_nodes == 1) or (num_nodes % 2 == 0)

    half = num_gpus // 2
    half_devices0 = devices[:half]
    half_devices1 = devices[half:]
    mesh_devices = [devices, half_devices0, half_devices1,
                    half_devices1 + half_devices0]

    for i, (mesh_shape, ds) in enumerate(zip(mesh_shapes, mesh_devices)):
        mesh = mtf.Mesh(graph, f'mesh{i}')
        meshes.append(mesh)
        mesh_to_impl[mesh] = utils.GetMeshImpl(mesh_shape, devices=ds,
                                               gpus_per_node=gpus_per_node)

    mtf_shape = utils.ConvertToShape([('axis0', batch_size)] +
                                     inputs.shape.as_list()[1:])
    # Inputs are placed on the first mesh and labels on the last one.
    mtf_inputs = mtf.import_tf_tensor(meshes[0], inputs, mtf_shape)
    mtf_labels = mtf.import_tf_tensor(meshes[-1], labels, mtf_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
def GetMeshImpl(dev_cnts, devices=None):
    """Forward to `utils.GetMeshImpl`, converting listed devices first.

    Args:
        dev_cnts: Passed unchanged as the first argument of
            `utils.GetMeshImpl`.
        devices: When this is a list, each entry `d` is replaced by
            `utils.GetDeviceStr(0, d)`; any other value (including `None`)
            is forwarded unchanged.

    Returns:
        Whatever `utils.GetMeshImpl` returns.
    """
    if not isinstance(devices, list):
        return utils.GetMeshImpl(dev_cnts, devices)

    device_strs = [utils.GetDeviceStr(0, dev) for dev in devices]
    return utils.GetMeshImpl(dev_cnts, device_strs)