Example #1
0
def model(params, inputs, labels):
    """Build a Mesh-TensorFlow graph: embedding -> RNN -> dense -> loss.

    The four meshes returned by ``CreateMeshes`` are unpacked as
    ``embed_mesh``, ``lstm0_mesh``, ``lstm1_mesh`` and ``proj_mesh``. The
    embedding is built on ``embed_mesh``, handed to ``RNNOperation`` via
    ``lstm0_mesh``, and the RNN output is moved to ``proj_mesh`` for the
    final dense layer and the loss.

    Args:
        params: Hyper-parameter object. This function reads ``num_nodes``,
            ``num_gpus``, ``batch_size``, ``num_units`` and ``vocab_size``.
        inputs: Input tensor; its shape must have exactly two entries
            (asserted below).
        labels: Label tensor, forwarded to ``CreateMeshes``.

    Returns:
        Tuple ``(graph, mesh_to_impl, mtf_loss)`` where ``mtf_loss`` is the
        ``mtf.reduce_mean`` of the softmax cross-entropy between the dense
        layer output and ``mtf_labels``.
    """
    # MTF mesh: create the graph, the meshes and the imported input/label
    # tensors.
    assert len(inputs.shape) == 2
    graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes(
            inputs, labels, params.num_nodes, params.num_gpus,
            params.batch_size)
    embed_mesh, lstm0_mesh, lstm1_mesh, proj_mesh = meshes
    # Dimension names used throughout this function.
    # NOTE(review): presumably these match the mesh axis names used by the
    # mesh implementations -- confirm against CreateMeshes / utils.
    batch_dim_name, n_dim_name, k_dim_name = 'axis0', 'axis1', 'axis2'

    # RNN weights: one variable per LSTM mesh, both with shape
    # (k_dim_name: 2*num_units, n_dim_name: 4*num_units).
    num_units = params.num_units
    w_shape = utils.ConvertToShape([(k_dim_name, 2*num_units),
        (n_dim_name, 4*num_units)])
    rnn_w0 = mtf.get_variable(lstm0_mesh, 'rnn_w0', w_shape)
    rnn_w1 = mtf.get_variable(lstm1_mesh, 'rnn_w1', w_shape)

    # RNN initial states: zero tensors of shape (batch, num_units). The two
    # shapes differ only in the name of the second dimension
    # (k_dim_name for h_shape, n_dim_name for c_shape).
    h_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
        mtf.Dimension(k_dim_name, num_units)])
    c_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
        mtf.Dimension(n_dim_name, num_units)])
    states0 = [mtf.zeros(lstm0_mesh, h_shape), mtf.zeros(lstm0_mesh, c_shape)]
    states1 = [mtf.zeros(lstm1_mesh, h_shape), mtf.zeros(lstm1_mesh, c_shape)]

    # Model - embedding
    vocab_dim = mtf.Dimension(k_dim_name, params.vocab_size)
    embed_dim = mtf.Dimension(n_dim_name, params.num_units)
    assert mtf_inputs.mesh == embed_mesh
    embedding = mtf.layers.embedding(mtf_inputs, vocab_dim, embed_dim,
            tf.float32)
    # The embedding's last dimension is named n_dim_name; the tensor is moved
    # to lstm0_mesh using dimension names in which that dimension is renamed
    # to k_dim_name.
    assert embedding.shape[-1].name == n_dim_name
    shape = embedding.shape.rename_dimension(n_dim_name, k_dim_name)
    embedding = mesh_trans.ReplaceMeshWithIndependentAxes(
            embedding, lstm0_mesh, shape.dimension_names)

    # Model - RNN
    # RNNOperation is defined elsewhere; it receives both weight variables
    # and the concatenated initial states, and has a single output.
    [y] = RNNOperation(embedding, rnn_w0, rnn_w1, num_units,
            states=states0 + states1).outputs
    assert y.mesh == lstm1_mesh
    assert y.shape[-1].name == k_dim_name
    assert mesh_to_impl[proj_mesh].shape[-1] == mtf.Dimension(k_dim_name, 1)
    # Rename the last dimension to a random name on the source mesh, then
    # move y to proj_mesh using dimension names where it is k_dim_name again.
    # NOTE(review): the intent of this rename round-trip is not visible here;
    # presumably it changes how that dimension is laid out across the mesh
    # during the transfer -- confirm against mesh_trans.
    rand_dim_name = utils.RandName()
    y = mt.rename_dimension(y, k_dim_name, rand_dim_name)
    shape = y.shape.rename_dimension(rand_dim_name, k_dim_name)
    y = mesh_trans.ReplaceMeshWithIndependentAxes(
            y, proj_mesh, shape.dimension_names)

    # Model - Dense + loss
    assert y.shape[-1].name == k_dim_name
    # vocab_dim is redefined here with the n_dim_name name (the embedding
    # above used k_dim_name for the vocabulary dimension).
    vocab_dim = mtf.Dimension(n_dim_name, params.vocab_size)
    # Contract over y's last dimension to produce logits over vocab_dim.
    y = mtf.layers.dense(y, vocab_dim, reduced_dims=y.shape[-1:],
            use_bias=False)
    assert mtf_labels.mesh == proj_mesh
    mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            y, mtf_labels, vocab_dim)
    mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    # Set an attribute on the function object itself.
    # NOTE(review): presumably read by the caller to configure device
    # placement -- confirm.
    model.soft_placement = True
    return graph, mesh_to_impl, mtf_loss
Example #2
0
def model(params, inputs, labels):
    """Build a Mesh-TensorFlow graph: embedding -> RNN -> dense -> loss.

    Uses three meshes from ``CreateMeshes``: ``meshes[0]`` and ``meshes[1]``
    hold the two RNN weight variables and their initial states, while
    ``meshes[2]`` holds the embedding, the dense layer and the loss.

    Args:
        params: Hyper-parameter object. This function reads ``num_nodes``,
            ``num_gpus``, ``batch_size``, ``num_units`` and ``vocab_size``.
        inputs: Input tensor; its shape must have exactly two entries
            (asserted below).
        labels: Label tensor, forwarded to ``CreateMeshes``.

    Returns:
        Tuple ``(graph, mesh_to_impl, mtf_loss)`` where ``mtf_loss`` is the
        ``mtf.reduce_mean`` of the softmax cross-entropy between the dense
        layer output and ``mtf_labels``.
    """
    # MTF mesh: create the graph, the meshes and the imported input/label
    # tensors.
    assert len(inputs.shape) == 2
    graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes(
            inputs, labels, params.num_nodes, params.num_gpus,
            params.batch_size)

    # Embedding dimensions (both get randomly generated names).
    vocab_dim = mtf.Dimension(utils.RandName(), params.vocab_size)
    embed_dim = mtf.Dimension(utils.RandName(), params.num_units)

    # The batch name comes from the imported inputs; the k name is shared
    # with the embedding dimension; the n name is a fresh random name.
    batch_dim_name = mtf_inputs.shape[0].name
    k_dim_name = embed_dim.name
    n_dim_name = utils.RandName()

    # RNN weights: one variable per RNN mesh, both with shape
    # (k_dim_name: 2*num_units, n_dim_name: 4*num_units).
    num_units = params.num_units
    w_shape = utils.ConvertToShape(
            [(k_dim_name, 2*num_units), (n_dim_name, 4*num_units)])
    rnn_w0 = mtf.get_variable(meshes[0], 'rnn_w0', w_shape)
    rnn_w1 = mtf.get_variable(meshes[1], 'rnn_w1', w_shape)

    # RNN initial states: zero tensors of shape (batch, num_units). The two
    # shapes differ only in the name of the second dimension.
    h_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
        mtf.Dimension(k_dim_name, num_units)])
    c_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
        mtf.Dimension(n_dim_name, num_units)])
    states0 = [mtf.zeros(meshes[0], h_shape), mtf.zeros(meshes[0], c_shape)]
    states1 = [mtf.zeros(meshes[1], h_shape), mtf.zeros(meshes[1], c_shape)]

    # Model
    embedding = mtf.layers.embedding(mtf_inputs, vocab_dim, embed_dim,
            tf.float32)
    assert embedding.mesh == meshes[2]
    # ReplaceRNNMesh is defined elsewhere; its first output is used as the
    # embedding handed to the RNN, with meshes[0] as the target mesh.
    embedding = ReplaceRNNMesh(embedding, meshes[0]).outputs[0]

    # RNNOperation is defined elsewhere; it receives both weight variables
    # and the concatenated initial states, and has a single output.
    [y] = RNNOperation(embedding, rnn_w0, rnn_w1, num_units,
            states=states0+states1).outputs
    assert y.mesh == meshes[1]
    assert y.shape[0].name == 'axis0'
    # Give y's first dimension the same name as the labels' first dimension
    # before moving it to meshes[2].
    y = mt.rename_dimension(y, 'axis0', mtf_labels.shape[0].name)
    y = mesh_trans.ReplaceMeshWithSimpleReplication(y, meshes[2])

    # vocab_dim is redefined with the name 'axis0' for the output layer.
    # NOTE(review): presumably so the vocabulary dimension is the one split
    # across the mesh in the dense layer -- confirm against the mesh impl.
    vocab_dim = mtf.Dimension('axis0', params.vocab_size)
    # Contract over y's last dimension to produce logits over vocab_dim.
    y = mtf.layers.dense(y, vocab_dim, reduced_dims=y.shape[-1:],
            use_bias=False)
    assert y.mesh == mtf_labels.mesh
    mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            y, mtf_labels, vocab_dim)
    mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    # Set an attribute on the function object itself.
    # NOTE(review): presumably read by the caller to configure device
    # placement -- confirm.
    model.soft_placement = True
    return graph, mesh_to_impl, mtf_loss
Example #3
0
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Create a single one-dimensional mesh and import inputs/labels onto it.

    Args:
        inputs: Tensor whose shape has two entries (asserted).
        labels: Tensor with the same shape as ``inputs`` (asserted).
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs; the mesh shape is ``[num_gpus]``.
        batch_size: Size of the first tensor dimension, named ``'axis0'``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.
    """
    graph = mtf.Graph()
    only_mesh = mtf.Mesh(graph, 'mesh0')
    meshes = [only_mesh]
    mesh_to_impl = {
        only_mesh: utils.GetMeshImpl(
            [num_gpus], gpus_per_node=num_gpus // num_nodes),
    }

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # First dimension is named 'axis0'; the second keeps the static size
    # taken from the inputs and is left unnamed.
    second_dim = inputs.shape.as_list()[1]
    tensor_shape = utils.ConvertToShape([('axis0', batch_size), second_dim])
    mtf_inputs = mtf.import_tf_tensor(only_mesh, inputs, tensor_shape)
    mtf_labels = mtf.import_tf_tensor(only_mesh, labels, tensor_shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
Example #4
0
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Create three meshes: one per half of the devices and one over all.

    ``mesh0`` uses the first half of the device list, ``mesh1`` the second
    half, and ``mesh2`` all devices, ordered by transposing and flattening
    the two halves. Inputs and labels are both imported onto ``mesh2``; the
    labels' first dimension gets a random name instead of ``'axis0'``.

    Args:
        inputs: Tensor whose shape has two entries (asserted).
        labels: Tensor with the same shape as ``inputs`` (asserted).
        num_nodes: Number of nodes; must divide ``num_gpus`` (asserted).
        num_gpus: Total number of GPUs; must be even (asserted).
        batch_size: Size of the first tensor dimension.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    assert num_gpus % num_nodes == 0
    assert num_gpus % 2 == 0
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)
    half = num_gpus // 2

    def new_mesh(name):
        # Create a mesh in this graph and record it in creation order.
        created = mtf.Mesh(graph, name)
        meshes.append(created)
        return created

    first_half_mesh = new_mesh('mesh0')
    mesh_to_impl[first_half_mesh] = utils.GetMeshImpl(
        [half], devices=devices[:half], gpus_per_node=gpus_per_node)

    second_half_mesh = new_mesh('mesh1')
    mesh_to_impl[second_half_mesh] = utils.GetMeshImpl(
        [half], devices=devices[half:], gpus_per_node=gpus_per_node)

    full_mesh = new_mesh('mesh2')
    # NOTE(review): presumably this interleaves the two halves
    # (d[0], d[half], d[1], ...) -- confirm against utils.TransposeLists.
    reordered_devices = utils.FlattenList(
        utils.TransposeLists([devices[:half], devices[half:]]))
    mesh_to_impl[full_mesh] = utils.GetMeshImpl(
        [num_gpus], devices=reordered_devices, gpus_per_node=gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    input_shape = utils.ConvertToShape(
        [('axis0', batch_size), inputs.shape.as_list()[1]])
    mtf_inputs = mtf.import_tf_tensor(meshes[2], inputs, input_shape)
    # Labels use the same sizes but a randomly named first dimension.
    label_shape = input_shape.rename_dimension('axis0', utils.RandName())
    mtf_labels = mtf.import_tf_tensor(meshes[2], labels, label_shape)

    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
Example #5
0
def CreateMeshes(inputs, labels, num_nodes, num_gpus, batch_size):
    """Create four 3-D meshes and import inputs/labels onto the outer two.

    Mesh shapes are ordered ``(batch_dim, n_dim, k_dim)``. ``mesh0`` and
    ``mesh3`` span all devices (``mesh3`` with the two device halves
    swapped); ``mesh1`` and ``mesh2`` each use one half of the devices.
    Inputs go on ``mesh0`` and labels on ``mesh3``.

    Args:
        inputs: Tensor whose shape has two entries (asserted).
        labels: Tensor with the same shape as ``inputs`` (asserted).
        num_nodes: Number of nodes; must be 1 or even (asserted).
        num_gpus: Total number of GPUs; one of 4, 8, 16, 32, 64.
        batch_size: Size of the first tensor dimension, named ``'axis0'``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels)``.

    Raises:
        AssertionError: If ``num_gpus`` is unsupported or a shape/node-count
            check fails.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    gpus_per_node = num_gpus // num_nodes
    devices = utils.GetDeviceList(num_gpus, gpus_per_node)

    assert len(inputs.shape) == 2
    assert inputs.shape == labels.shape

    # Mesh_shape: batch_dim, n_dim, k_dim -- one row per mesh, keyed by the
    # total GPU count.
    shapes_by_gpu_count = {
        4: [[1, 1, 4], [2, 1, 1], [2, 1, 1], [1, 4, 1]],
        8: [[1, 1, 8], [4, 1, 1], [4, 1, 1], [1, 8, 1]],
        16: [[1, 1, 16], [2, 2, 2], [2, 2, 2], [1, 16, 1]],
        32: [[1, 1, 32], [4, 2, 2], [4, 2, 2], [1, 32, 1]],
        64: [[1, 1, 64], [8, 2, 2], [8, 2, 2], [1, 64, 1]],
    }
    assert num_gpus in shapes_by_gpu_count
    mesh_shapes = shapes_by_gpu_count[num_gpus]

    # The two middle meshes must be identical and each cover half the GPUs.
    assert mesh_shapes[1] == mesh_shapes[2]
    assert (utils.Prod(mesh_shapes[1]) == utils.Prod(mesh_shapes[2]) ==
            num_gpus // 2)
    assert (num_nodes == 1) or (num_nodes % 2 == 0)

    split_at = num_gpus // 2
    lower_half = devices[:split_at]
    upper_half = devices[split_at:]
    mesh_devices = [devices, lower_half, upper_half, upper_half + lower_half]

    for index, (mesh_shape, device_list) in enumerate(
            zip(mesh_shapes, mesh_devices)):
        new_mesh = mtf.Mesh(graph, f'mesh{index}')
        meshes.append(new_mesh)
        mesh_to_impl[new_mesh] = utils.GetMeshImpl(
            mesh_shape, devices=device_list, gpus_per_node=gpus_per_node)

    trailing_dims = inputs.shape.as_list()[1:]
    mtf_shape = utils.ConvertToShape([('axis0', batch_size)] + trailing_dims)
    mtf_inputs = mtf.import_tf_tensor(meshes[0], inputs, mtf_shape)
    mtf_labels = mtf.import_tf_tensor(meshes[-1], labels, mtf_shape)
    return graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels
Example #6
0
def CreateMeshes(args, img, labels, num_nodes, num_gpus):
    """Create the meshes for the chosen strategy and import image/labels.

    Every strategy creates ``mesh0`` with shape ``[num_gpus]`` and imports
    the image onto it with shape ``(axis0: batch_size, 299, 299, 3)``.
    Strategies differ in the second mesh and in how labels are imported:

    * 0: labels on ``mesh0`` with a named ``'axis0'`` dimension.
    * 1: ``mesh1`` is ``[4, 1]`` for 4 GPUs, else ``[num_gpus // 2, 2]``;
      labels on ``mesh1`` with an unnamed dimension.
    * 2: labels on ``mesh0`` with an unnamed dimension.
    * 3: ``mesh1`` is ``[2, num_gpus // 2]``; labels on ``mesh1`` with a
      named ``'axis0'`` dimension.

    Args:
        args: Object providing ``strategy`` and ``batch_size``.
        img: Image tensor to import.
        labels: Label tensor to import.
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_img, mtf_labels)``.

    Raises:
        AssertionError: If ``args.strategy`` is not 0, 1, 2 or 3.
    """
    height, width, channels = 299, 299, 3
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    strategy = args.strategy
    batch_size = args.batch_size
    gpus_per_node = num_gpus // num_nodes

    def add_mesh(mesh_shape):
        # Meshes are named by creation order: mesh0, mesh1, ...
        created = mtf.Mesh(graph, 'mesh%d' % len(meshes))
        meshes.append(created)
        mesh_to_impl[created] = utils.GetMeshImpl(
            mesh_shape, gpus_per_node=gpus_per_node)

    # Each branch creates its meshes and picks where/how labels are imported.
    if strategy == 0:
        add_mesh([num_gpus])
        label_mesh_index, named_label_dim = 0, True
    elif strategy == 1:
        add_mesh([num_gpus])
        add_mesh([4, 1] if num_gpus == 4 else [num_gpus // 2, 2])
        label_mesh_index, named_label_dim = 1, False
    elif strategy == 2:
        add_mesh([num_gpus])
        label_mesh_index, named_label_dim = 0, False
    elif strategy == 3:
        add_mesh([num_gpus])
        add_mesh([2, num_gpus // 2])
        label_mesh_index, named_label_dim = 1, True
    else:
        assert False

    img_shape = utils.ConvertToShape(
        [('axis0', batch_size), height, width, channels])
    mtf_img = mtf.import_tf_tensor(meshes[0], img, img_shape)

    if named_label_dim:
        label_dims = [('axis0', batch_size)]
    else:
        label_dims = [batch_size]
    mtf_labels = mtf.import_tf_tensor(
        meshes[label_mesh_index], labels, utils.ConvertToShape(label_dims))

    return graph, meshes, mesh_to_impl, mtf_img, mtf_labels
Example #7
0
def CreateMeshes(img, labels, num_nodes, num_gpus, args):
    """Create the meshes for the chosen strategy and import image/labels.

    Every strategy creates ``mesh0`` with shape ``[num_gpus]`` and imports
    the image onto it with shape ``(axis0: batch_size, 227, 227, 3)``.
    Strategies differ in the extra meshes and in how labels are imported:

    * 0: labels on ``mesh0`` with a named ``'axis0'`` dimension.
    * 1: adds ``mesh1`` ``[dim1, dim2]`` and ``mesh2`` ``[1, dim2]`` from a
      fixed per-GPU-count table; labels on the last mesh, unnamed dimension.
    * 2: labels on ``mesh0`` with an unnamed dimension.
    * 3: adds ``mesh1`` ``[2, num_gpus // 2]``; labels on ``mesh1`` with a
      named ``'axis0'`` dimension.

    Args:
        img: Image tensor to import.
        labels: Label tensor to import.
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs.
        args: Object providing ``strategy`` and ``batch_size``.

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_img, mtf_labels)``.

    Raises:
        AssertionError: If the strategy is unknown, or strategy 1 is used
            with a GPU count other than 4, 8, 16, 32 or 64.
    """
    height, width, channels = 227, 227, 3
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}

    strategy = args.strategy
    batch_size = args.batch_size
    gpus_per_node = num_gpus // num_nodes

    make_impl = functools.partial(utils.GetMeshImpl,
                                  gpus_per_node=gpus_per_node)

    def add_mesh(mesh_shape):
        # Meshes are named by creation order: mesh0, mesh1, ...
        created = mtf.Mesh(graph, 'mesh%d' % len(meshes))
        meshes.append(created)
        mesh_to_impl[created] = make_impl(mesh_shape)

    # Each branch creates its meshes and picks where/how labels are imported.
    if strategy == 0:
        add_mesh([num_gpus])
        label_mesh_index, named_label_dim = 0, True

    elif strategy == 1:
        add_mesh([num_gpus])

        # 2-D factorization of the GPU count used by mesh1.
        factorizations = {
            4: (4, 1),
            8: (4, 2),
            16: (8, 2),
            32: (8, 4),
            64: (8, 8),
        }
        assert num_gpus in factorizations
        dim1, dim2 = factorizations[num_gpus]
        assert ((dim1 * dim2) == num_gpus)

        add_mesh([dim1, dim2])
        add_mesh([1, dim2])
        label_mesh_index, named_label_dim = -1, False

    elif strategy == 2:
        add_mesh([num_gpus])
        label_mesh_index, named_label_dim = 0, False

    elif strategy == 3:
        add_mesh([num_gpus])
        add_mesh([2, num_gpus // 2])
        label_mesh_index, named_label_dim = 1, True

    else:
        assert False

    img_shape = utils.ConvertToShape(
        [('axis0', batch_size), height, width, channels])
    mtf_img = mtf.import_tf_tensor(meshes[0], img, img_shape)

    if named_label_dim:
        label_dims = [('axis0', batch_size)]
    else:
        label_dims = [batch_size]
    mtf_labels = mtf.import_tf_tensor(
        meshes[label_mesh_index], labels, utils.ConvertToShape(label_dims))

    return graph, meshes, mesh_to_impl, mtf_img, mtf_labels
Example #8
0
def CreateMeshes(strategy, src, tgt, num_nodes, num_gpus, params):
    """Create the meshes for the chosen strategy and import src/tgt tensors.

    Strategies:

    * 0 (data-parallel): one mesh ``[num_gpus]``; tensors imported with a
      named ``'axis0'`` batch dimension.
    * 1 (strategy from the tool): ``mesh0`` is ``[dim1, dim2]`` from a fixed
      table (depending on ``params.model_size`` for 16 and 64 GPUs) and
      ``mesh1`` is ``[num_gpus]``; tensors are imported onto ``mesh1`` with
      unnamed dimensions.
    * 2 (Mesh-TensorFlow paper strategy): one mesh ``[dim1, dim2]`` from a
      fixed table; tensors imported with a named ``'axis0'`` batch dimension.

    Args:
        strategy: Integer strategy id (0, 1 or 2).
        src: Source tensor to import.
        tgt: Target tensor to import.
        num_nodes: Number of nodes; used to derive GPUs per node.
        num_gpus: Total number of GPUs (4, 8, 16, 32 or 64 for strategies
            1 and 2).
        params: Object providing ``batch_size`` and ``max_seq_len`` (and
            ``model_size`` for strategy 1 with 16 or 64 GPUs).

    Returns:
        Tuple ``(graph, meshes, mesh_to_impl, mtf_src, mtf_tgt)``; both
        tensors are cast to ``tf.int32``.

    Raises:
        AssertionError: If the strategy or GPU count is unsupported.
    """
    graph = mtf.Graph()
    meshes = []
    mesh_to_impl = {}
    gpus_per_node = num_gpus // num_nodes

    mesh_id = 0

    def CreateMesh(mesh_shape):
        # Create 'mesh<N>' (N = creation order), record it and its impl.
        nonlocal mesh_id
        mesh = mtf.Mesh(graph, f'mesh{mesh_id}')
        meshes.append(mesh)
        mesh_id += 1
        mesh_to_impl[mesh] = utils.GetMeshImpl(mesh_shape,
                                               gpus_per_node=gpus_per_node)
        return mesh

    if strategy == 0:  # Data-parallel
        mesh = CreateMesh([num_gpus])
        shape = utils.ConvertToShape([('axis0', params.batch_size),
                                      params.max_seq_len])
        mtf_src = mtf.import_tf_tensor(mesh, src, shape)
        # Fixed: the target tensor was previously imported from `src`.
        mtf_tgt = mtf.import_tf_tensor(mesh, tgt, shape)

    elif strategy == 1:  # Opt strategy from the tool
        if num_gpus == 4:
            dim1, dim2 = 4, 1
        elif num_gpus == 8:
            dim1, dim2 = 8, 1
        elif num_gpus == 16:
            if params.model_size == 'small':
                dim1, dim2 = 16, 1
            else:
                dim1, dim2 = 8, 2
        elif num_gpus == 32:
            dim1, dim2 = 16, 2
        elif num_gpus == 64:
            if params.model_size == 'small':
                dim1, dim2 = 32, 2
            else:
                dim1, dim2 = 16, 4
        else:
            assert False
        assert ((dim1 * dim2) == num_gpus)

        # The 2-D mesh must be created first so it is meshes[0] / 'mesh0';
        # only the 1-D mesh is needed locally for the imports.
        CreateMesh([dim1, dim2])
        mesh1 = CreateMesh([num_gpus])

        shape = utils.ConvertToShape([params.batch_size, params.max_seq_len])
        mtf_src = mtf.import_tf_tensor(mesh1, src, shape)
        mtf_tgt = mtf.import_tf_tensor(mesh1, tgt, shape)

    elif strategy == 2:  # Strategy from mesh-tensorflow paper
        if num_gpus == 4:
            dim1, dim2 = 4, 1
        elif num_gpus == 8:
            dim1, dim2 = 2, 4
        elif num_gpus == 16:
            dim1, dim2 = 4, 4
        elif num_gpus == 32:
            dim1, dim2 = 4, 8
        elif num_gpus == 64:
            dim1, dim2 = 8, 8
        else:
            assert False
        assert ((dim1 * dim2) == num_gpus)

        mesh = CreateMesh([dim1, dim2])
        shape = utils.ConvertToShape([('axis0', params.batch_size),
                                      params.max_seq_len])
        mtf_src = mtf.import_tf_tensor(mesh, src, shape)
        # Fixed: the target tensor was previously imported from `src`.
        mtf_tgt = mtf.import_tf_tensor(mesh, tgt, shape)

    else:
        assert False

    mtf_src = mtf.cast(mtf_src, tf.int32)
    mtf_tgt = mtf.cast(mtf_tgt, tf.int32)
    return graph, meshes, mesh_to_impl, mtf_src, mtf_tgt