Code Example #1
File: rnnlm_opt.py  Project: baidu-research/PaSE
import mesh_tensorflow as mtf
import tensorflow as tf
# NOTE: all examples in this listing additionally assume project-local helpers
# (utils, mesh_trans, mt, CreateMeshes, RNNOperation, GetMeshImpl, GetShape,
# RandName, Run) from the PaSE codebase.

def model(params, inputs, labels):
    # MTF mesh
    assert len(inputs.shape) == 2
    graph, meshes, mesh_to_impl, mtf_inputs, mtf_labels = CreateMeshes(
            inputs, labels, params.num_nodes, params.num_gpus,
            params.batch_size)
    embed_mesh, lstm0_mesh, lstm1_mesh, proj_mesh = meshes
    batch_dim_name, n_dim_name, k_dim_name = 'axis0', 'axis1', 'axis2'

    # RNN weights
    num_units = params.num_units
    w_shape = utils.ConvertToShape([(k_dim_name, 2*num_units),
        (n_dim_name, 4*num_units)])
    rnn_w0 = mtf.get_variable(lstm0_mesh, 'rnn_w0', w_shape)
    rnn_w1 = mtf.get_variable(lstm1_mesh, 'rnn_w1', w_shape)

    # RNN initial states
    h_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
        mtf.Dimension(k_dim_name, num_units)])
    c_shape = mtf.Shape([mtf.Dimension(batch_dim_name, params.batch_size),
        mtf.Dimension(n_dim_name, num_units)])
    states0 = [mtf.zeros(lstm0_mesh, h_shape), mtf.zeros(lstm0_mesh, c_shape)]
    states1 = [mtf.zeros(lstm1_mesh, h_shape), mtf.zeros(lstm1_mesh, c_shape)]

    # Model - embedding
    vocab_dim = mtf.Dimension(k_dim_name, params.vocab_size)
    embed_dim = mtf.Dimension(n_dim_name, params.num_units)
    assert mtf_inputs.mesh == embed_mesh
    embedding = mtf.layers.embedding(mtf_inputs, vocab_dim, embed_dim,
            tf.float32)
    assert embedding.shape[-1].name == n_dim_name
    shape = embedding.shape.rename_dimension(n_dim_name, k_dim_name)
    embedding = mesh_trans.ReplaceMeshWithIndependentAxes(
            embedding, lstm0_mesh, shape.dimension_names)

    # Model - RNN
    [y] = RNNOperation(embedding, rnn_w0, rnn_w1, num_units,
            states=states0 + states1).outputs
    assert y.mesh == lstm1_mesh
    assert y.shape[-1].name == k_dim_name
    assert mesh_to_impl[proj_mesh].shape[-1] == mtf.Dimension(k_dim_name, 1)
    rand_dim_name = utils.RandName()
    y = mt.rename_dimension(y, k_dim_name, rand_dim_name)
    shape = y.shape.rename_dimension(rand_dim_name, k_dim_name)
    y = mesh_trans.ReplaceMeshWithIndependentAxes(
            y, proj_mesh, shape.dimension_names)

    # Model - Dense + loss
    assert y.shape[-1].name == k_dim_name
    vocab_dim = mtf.Dimension(n_dim_name, params.vocab_size)
    y = mtf.layers.dense(y, vocab_dim, reduced_dims=y.shape[-1:],
            use_bias=False)
    assert mtf_labels.mesh == proj_mesh
    mtf_cross_ent = mtf.layers.softmax_cross_entropy_with_logits(
            y, mtf_labels, vocab_dim)
    mtf_loss = mtf.reduce_mean(mtf_cross_ent)

    # Attach a flag on the function object, presumably read by the caller to
    # enable TF soft device placement.
    model.soft_placement = True
    return graph, mesh_to_impl, mtf_loss
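
The returned graph, mesh_to_impl, and mtf_loss feed Mesh TensorFlow's standard lowering step. A minimal sketch of that step using the stock mtf.Lowering API (the wrapper function and its name are illustrative, not part of the source):

import mesh_tensorflow as mtf

def lower_loss(graph, mesh_to_impl, mtf_loss):
    # Lower the mtf graph onto the concrete device meshes, then export the
    # scalar loss back as an ordinary TF tensor.
    lowering = mtf.Lowering(graph, mesh_to_impl)
    return lowering.export_to_tf_tensor(mtf_loss)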
Code Example #2
    # This is the inner function also shown in Code Example #8; it relies on
    # variables from the enclosing scope (strategy, meshes, model_dim,
    # d_k_dim, heads_dim, ff_dim, params).
    def encoder_decoder(inp, enc_out, vocab_dim, name):
        with tf.variable_scope(name):
            # Embedding
            embed = mtf.layers.embedding(inp,
                                         vocab_dim,
                                         model_dim,
                                         tf.float32,
                                         name=f'{name}_embedding')
            if strategy == 1:
                check_distribution(embed, meshes[1], {})
                shape = embed.shape
                shape = shape.rename_dimension(shape[0].name, 'axis0')
                embed = mesh_trans.ReplaceMeshWithIndependentAxes(
                    embed, meshes[0], dim_names=shape.dimension_names)
            check_distribution(embed, meshes[0], {0: 'axis0'})

            # Positional encoding
            x = positional_encoding(embed)
            check_distribution(x, meshes[0], {0: 'axis0'})

            # Encoder/decoder layers
            for i in range(params.nx):
                # Multihead attention
                y = mtf.layers.multihead_attention(x,
                                                   None,
                                                   None,
                                                   d_k_dim,
                                                   heads_dim,
                                                   dropout=0.5,
                                                   name=f'{name}_att_{i}')
                x = add_norm(x, y, name=f'{name}_att_{i}_norm')
                check_distribution(x, meshes[0], {0: 'axis0'})

                if enc_out is not None:
                    y = mtf.layers.multihead_attention(x,
                                                       enc_out,
                                                       None,
                                                       d_k_dim,
                                                       heads_dim,
                                                       dropout=0.5,
                                                       name=f'{name}_att2_{i}')
                    x = add_norm(x, y, name=f'{name}_att2_{i}_norm')
                    check_distribution(x, meshes[0], {0: 'axis0'})

                # Feed forward
                y = mtf.layers.dense_relu_dense(x,
                                                ff_dim,
                                                dropout=0.5,
                                                name=f'{name}_ff_{i}')
                x = add_norm(x, y, name=f'{name}_ff_{i}_norm')
                check_distribution(x, meshes[0], {0: 'axis0'})
            return x
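
add_norm and positional_encoding are project helpers not shown in this listing. The call pattern x = add_norm(x, y, ...) suggests the usual residual-add-then-layer-norm; a purely illustrative guess at such a helper in Mesh TensorFlow:

import mesh_tensorflow as mtf

def add_norm(x, y, name='add_norm'):
    # Residual connection followed by layer norm over the model dimension.
    # A guess at the helper's behavior, not the project's actual code.
    return mtf.layers.layer_norm(x + y, dim=x.shape.dims[-1], name=name)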
Code Example #3
def Transpose1(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]), mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([('axis0', shape[0]), ('axis1', shape[1]),
                          *shape[2:]])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1,
        [RandName(), RandName(), 'axis0', 'axis1'])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
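
For context, a hedged sketch of how a test like Transpose1 might be driven. The input shape here is made up (it only needs dimensions divisible by the 4x2 mesh axes), and Run/GetShape/GetMeshImpl/RandName remain project helpers not shown in this listing:

import tensorflow as tf

# Hypothetical driver: any 4-D tensor whose dimensions are divisible by the
# mesh axis sizes (4 and 2) at the positions named 'axis0'/'axis1' works.
in_tsr = tf.random.uniform([4, 8, 16, 32])
Transpose1(in_tsr)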
Code Example #4
def Contract2(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2, 4]),
                    mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape)
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1, [RandName(), 'axis0', 'axis1',
                            RandName()])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Code Example #5
def MoreDevices(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([2]),
                    mesh1: GetMeshImpl([8])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape(shape[:-1] + [('axis0', shape[-1])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1,
        [RandName(), 'axis0', RandName(),
         RandName()])
    Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
Code Example #6
def DependentAxes(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]), mesh1: GetMeshImpl([4, 2])}

    shape = in_tsr.get_shape().as_list()
    mtf_shape = GetShape([('axis0', shape[0]), ('axis1', shape[1]),
                          *shape[2:]])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1, [RandName(), 'axis0', 'axis1',
                            RandName()])

    try:
        Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
        assert False  # This run should fail and raise a ValueError
    except ValueError:
        return
Code Example #7
def WrongShape(in_tsr):
    graph = mtf.Graph()
    mesh0 = mtf.Mesh(graph, 'mesh0')
    mesh1 = mtf.Mesh(graph, 'mesh1')
    mesh_to_impl = {mesh0: GetMeshImpl([4, 2]),
                    mesh1: GetMeshImpl([8])}

    shape = in_tsr.get_shape().as_list()
    # 'axis0' is deliberately assigned to two tensor dimensions, producing an
    # invalid shape that the mesh replacement should reject.
    mtf_shape = GetShape(shape[:-2] +
                         [('axis0', shape[2]), ('axis0', shape[3])])
    mtf_in_tsr = mtf.import_tf_tensor(mesh0, in_tsr, mtf_shape)
    mtf_out_tsr = mt.ReplaceMeshWithIndependentAxes(
        mtf_in_tsr, mesh1,
        [RandName(), 'axis0', RandName(),
         RandName()])

    try:
        Run(graph, mesh_to_impl, in_tsr, mtf_out_tsr)
        assert False  # This test should fail with a ValueError
    except ValueError:
        return
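
The try/assert-False pattern in DependentAxes and WrongShape could equally be written with pytest's raises helper; a minimal sketch of such a wrapper (the function name is illustrative, and run_fn stands in for the project's Run helper):

import pytest

def assert_run_rejects(run_fn, *args):
    # Expect run_fn to reject the given graph/mesh arguments.
    with pytest.raises(ValueError):
        run_fn(*args)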
Code Example #8
def Transformer(src, tgt, params, src_vocab_size, tgt_vocab_size, strategy,
                num_nodes, num_gpus):
    graph, meshes, mesh_to_impl, mtf_src, mtf_tgt = CreateMeshes(
        strategy, src, tgt, num_nodes, num_gpus, params)
    src_vocab_size = utils.RoundUp(src_vocab_size, num_gpus)
    tgt_vocab_size = utils.RoundUp(tgt_vocab_size, num_gpus)

    # mtf dimensions
    if strategy == 0:
        src_vocab_dim = mtf.Dimension(RandName(), src_vocab_size)
        tgt_vocab_dim = mtf.Dimension(RandName(), tgt_vocab_size)
        model_dim = mtf.Dimension(RandName(), params.d_model)
        d_k_dim = mtf.Dimension(RandName(), params.d_k)
        heads_dim = mtf.Dimension(RandName(), params.heads)
        ff_dim = mtf.Dimension(RandName(), params.d_ff)
    elif strategy == 1:
        src_vocab_dim = mtf.Dimension('axis0', src_vocab_size)
        tgt_vocab_dim = mtf.Dimension('axis0', tgt_vocab_size)
        model_dim = mtf.Dimension(RandName(), params.d_model)
        d_k_dim = mtf.Dimension(RandName(), params.d_k)
        heads_dim = mtf.Dimension('axis1', params.heads)
        ff_dim = mtf.Dimension('axis1', params.d_ff)
    elif strategy == 2:
        src_vocab_dim = mtf.Dimension('axis1', src_vocab_size)
        tgt_vocab_dim = mtf.Dimension('axis1', tgt_vocab_size)
        model_dim = mtf.Dimension(RandName(), params.d_model)
        d_k_dim = mtf.Dimension(RandName(), params.d_k)
        heads_dim = mtf.Dimension('axis1', params.heads)
        ff_dim = mtf.Dimension('axis1', params.d_ff)
    else:
        assert False, f'unknown strategy: {strategy}'
    seq_len_dim = mtf_src.shape[-1]
    assert mtf_src.shape[-1] == mtf_tgt.shape[-1]

    if strategy == 1:
        check_distribution(mtf_src, meshes[1], {})
        check_distribution(mtf_tgt, meshes[1], {})
    else:
        check_distribution(mtf_src, meshes[0], {0: 'axis0'})
        check_distribution(mtf_tgt, meshes[0], {0: 'axis0'})

    def encoder_decoder(inp, enc_out, vocab_dim, name):
        with tf.variable_scope(name):
            # Embedding
            embed = mtf.layers.embedding(inp,
                                         vocab_dim,
                                         model_dim,
                                         tf.float32,
                                         name=f'{name}_embedding')
            if strategy == 1:
                check_distribution(embed, meshes[1], {})
                shape = embed.shape
                shape = shape.rename_dimension(shape[0].name, 'axis0')
                embed = mesh_trans.ReplaceMeshWithIndependentAxes(
                    embed, meshes[0], dim_names=shape.dimension_names)
            check_distribution(embed, meshes[0], {0: 'axis0'})

            # Positional encoding
            x = positional_encoding(embed)
            check_distribution(x, meshes[0], {0: 'axis0'})

            # Encoder/decoder layers
            for i in range(params.nx):
                # Multihead attention
                y = mtf.layers.multihead_attention(x,
                                                   None,
                                                   None,
                                                   d_k_dim,
                                                   heads_dim,
                                                   dropout=0.5,
                                                   name=f'{name}_att_{i}')
                x = add_norm(x, y, name=f'{name}_att_{i}_norm')
                check_distribution(x, meshes[0], {0: 'axis0'})

                if enc_out is not None:
                    y = mtf.layers.multihead_attention(x,
                                                       enc_out,
                                                       None,
                                                       d_k_dim,
                                                       heads_dim,
                                                       dropout=0.5,
                                                       name=f'{name}_att2_{i}')
                    x = add_norm(x, y, name=f'{name}_att2_{i}_norm')
                    check_distribution(x, meshes[0], {0: 'axis0'})

                # Feed forward
                y = mtf.layers.dense_relu_dense(x,
                                                ff_dim,
                                                dropout=0.5,
                                                name=f'{name}_ff_{i}')
                x = add_norm(x, y, name=f'{name}_ff_{i}_norm')
                check_distribution(x, meshes[0], {0: 'axis0'})
            return x

    # Encoder/Decoder
    enc_out = encoder_decoder(mtf_src, None, src_vocab_dim, 'encoder')
    check_distribution(enc_out, meshes[0], {0: 'axis0'})
    enc_out = mt.rename_dimension(enc_out, enc_out.shape[1].name, RandName())
    dec_out = encoder_decoder(mtf_tgt, enc_out, tgt_vocab_dim, 'decoder')

    # Loss function
    with tf.variable_scope('loss'):
        check_distribution(dec_out, meshes[0], {0: 'axis0'})
        if strategy == 1:
            shape = dec_out.shape.rename_dimension('axis0',
                                                   mtf_tgt.shape[0].name)
            dec_out = mesh_trans.ReplaceMeshWithIndependentAxes(
                dec_out, meshes[1], dim_names=shape.dimension_names)
            check_distribution(dec_out, meshes[1], {})

        out = mtf.layers.dense(dec_out,
                               tgt_vocab_dim,
                               use_bias=False,
                               reduced_dims=dec_out.shape[-1:],
                               name='final_projection')
        assert out.mesh == mtf_tgt.mesh
        assert (out.shape.dims == mtf_tgt.shape.dims + [tgt_vocab_dim])
        out = mtf.layers.softmax_cross_entropy_with_logits(
            out, mtf_tgt, tgt_vocab_dim)
        loss = mtf.reduce_mean(out)

    return graph, mesh_to_impl, loss
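
From here, the usual Mesh TensorFlow recipe turns the returned loss into a training step. A hedged sketch using the stock mtf APIs (the wrapper function and learning rate are illustrative, not from the source):

import mesh_tensorflow as mtf

def build_train_op(graph, mesh_to_impl, loss, lr=0.01):
    # Differentiate the mtf loss w.r.t. all trainable variables, apply SGD,
    # then lower the graph and export the loss and update ops as TF objects.
    var_grads = mtf.gradients(
        [loss], [v.outputs[0] for v in graph.trainable_variables])
    update_ops = mtf.optimize.SgdOptimizer(lr).apply_grads(
        var_grads, graph.trainable_variables)
    lowering = mtf.Lowering(graph, mesh_to_impl)
    tf_loss = lowering.export_to_tf_tensor(loss)
    tf_updates = [lowering.lowered_operation(op) for op in update_ops]
    return tf_loss, tf_updates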