Exemplo n.º 1
0
def linear(x, dim, bias=True, ln=False,
           weight_initializer=None,
           bias_initializer=None,
           scope=None):
    """
    Linear (feed-forward) projection of one or several inputs.

    Without layer normalization all inputs are concatenated on the last
    axis and projected with a single weight matrix. With layer
    normalization every input gets its own weight matrix and its own
    normalization, and the normalized projections are summed.

    :param x: input tensor, or a list/tuple of input tensors
    :param dim: output dimension, or a list/tuple of output dimensions
    :param bias: whether to add a bias term per output
    :param ln: whether to layer-normalize each projection
    :param weight_initializer: initializer for the `W_*` variables
    :param bias_initializer: initializer for the `b_*` variables
    :param scope: variable scope name, "linear" when not given
    :return: a single tensor when one output dimension was requested,
        otherwise a list with one tensor per output dimension
    """
    with tf.variable_scope(scope or "linear", values=[x]):
        inputs = x if isinstance(x, (list, tuple)) else [x]
        out_dims = dim if isinstance(dim, (list, tuple)) else [dim]

        if not ln:
            # one joint projection over the concatenated inputs
            inputs = [tf.concat(inputs, -1)]

        projections = []
        for out_idx, out_dim in enumerate(out_dims):
            partials = []
            for in_idx, tensor in enumerate(inputs):
                in_dim = util.shape_list(tensor)[-1]

                weight = tf.get_variable(
                    "W_{}_{}".format(out_idx, in_idx), [in_dim, out_dim],
                    initializer=weight_initializer)
                # collapse all leading axes so a plain 2-D matmul applies
                flat = tf.reshape(tensor, [-1, in_dim])
                partial = tf.matmul(flat, weight)

                if ln:
                    partial = layer_norm(
                        partial,
                        scope="ln_{}_{}".format(out_idx, in_idx))
                partials.append(partial)

            summed = tf.add_n(partials)

            if bias:
                offset = tf.get_variable(
                    "b_{}".format(out_idx), [out_dim],
                    initializer=bias_initializer)
                summed = tf.nn.bias_add(summed, offset)

            # restore the leading axes of the (first) input
            leading = util.shape_list(inputs[0])[:-1]
            restored = tf.reshape(summed, tf.concat([leading, [out_dim]], 0))

            projections.append(restored)

        return projections[0] if len(projections) == 1 else projections
Exemplo n.º 2
0
def depthwise_conv(inputs,
                   hidden_size,
                   kernel_size=1,
                   bias=True,
                   activation=None,
                   scope='depthwise_conv'):
    """Depthwise-separable convolution with optional bias and activation.

    Creates a `(kernel_size, 1, channels, 1)` depthwise filter and a
    `(1, 1, channels, hidden_size)` pointwise filter, where `channels` is
    the last dimension of `inputs`, and applies `tf.nn.separable_conv2d`
    with unit strides and 'SAME' padding.

    :param inputs: input tensor (assumed 4-D, as `separable_conv2d`
        expects -- TODO confirm against callers)
    :param hidden_size: number of output channels
    :param kernel_size: extent of the depthwise filter on its first axis
    :param bias: whether to add a zero-initialized bias
    :param activation: optional callable applied to the result
    :param scope: variable scope name
    :return: the convolved (and optionally activated) tensor
    """
    with tf.variable_scope(scope or "depthwise_conv"):
        in_channels = util.shape_list(inputs)[-1]

        dw_kernel = tf.get_variable('depthwise_filter',
                                    (kernel_size, 1, in_channels, 1))
        pw_kernel = tf.get_variable('pointwise_filter',
                                    (1, 1, in_channels, hidden_size))

        conv_out = tf.nn.separable_conv2d(inputs,
                                          dw_kernel,
                                          pw_kernel,
                                          strides=(1, 1, 1, 1),
                                          padding='SAME')
        if bias:
            offset = tf.get_variable('bias',
                                     conv_out.shape[-1],
                                     initializer=tf.zeros_initializer())
            conv_out = conv_out + offset

        if activation is None:
            return conv_out
        return activation(conv_out)
Exemplo n.º 3
0
def graph(features, params):
    """Assemble the model: optional BERT features, embedding, hierarchy, loss.

    When `params.enable_bert` is set, `features['s']` is encoded with BERT
    and the encodings at the positions in `features['sb']` are stored as
    `features['bert_enc']` (after dropout) together with a
    `features['feature']` summary vector.

    :param features: dict of input tensors, extended in place
    :param params: hyper-parameters (`enable_bert`, `use_bert_single`,
        `dropout` are read here)
    :return: the output of `loss_layer`
    """
    if params.enable_bert:
        bert_outputs = bert.bert_encoder(features['s'], params)
        # join the last four entries of the first BERT output on the feature
        # axis, then drop position 0 of the sequence axis
        joined = tf.concat(bert_outputs[0][-4:], -1)
        piece_enc = joined[:, 1:, :]

        positions = features['sb']
        pos_shape = util.shape_list(positions)
        rows = util.batch_coordinates(pos_shape[0], pos_shape[1])
        picked = tf.gather_nd(piece_enc,
                              tf.stack([rows, positions], axis=2))

        features['bert_enc'] = util.valid_apply_dropout(picked, params.dropout)
        if params.use_bert_single:
            features['feature'] = bert_outputs[1]
        else:
            features['feature'] = picked[:, 0, :]

    features = embedding_layer(features, params)
    features = hierarchy_layer(features, params)
    return loss_layer(features, params)
Exemplo n.º 4
0
def extract_encodes(source_memory, source_mask, l0_mask):
    """Gather the encodings retained by `l0_mask` and prepend one zero slot.

    :param source_memory: encoder outputs, indexed as [batch, position, ...]
    :param source_mask: validity mask over source positions
    :param l0_mask: gate values with a trailing singleton axis; any non-zero
        entry counts as "retained"
    :return: `(g_x, g_mask, count_mask)` where `g_x` holds the retained
        encodings behind a zero-padded first slot, `g_mask` is their mask
        (first slot set when the row dropped at least one token) and
        `count_mask` is all ones except the first slot, which carries the
        number of dropped tokens (1 when nothing was dropped)
    """
    mem_shape = util.shape_list(source_memory)

    retained = dtype.tf_to_float(tf.cast(l0_mask, tf.bool))
    retained = tf.squeeze(retained, -1) * source_mask

    # largest number of retained positions over the batch
    max_kept = tf.cast(tf.reduce_max(tf.reduce_sum(retained, 1)), tf.int32)
    # batch_size x max_kept
    kept_indices = tf.nn.top_k(retained, max_kept)[1]

    # (row, position) pairs for gather_nd
    rows = util.batch_coordinates(mem_shape[0], max_kept)
    coordinates = tf.stack([rows, kept_indices], axis=2)

    kept_memory = tf.gather_nd(source_memory, coordinates)
    kept_mask = tf.gather_nd(retained, coordinates)

    # reserve a zero-valued slot at the front of the position axis
    kept_memory = tf.pad(kept_memory, [[0, 0], [1, 0], [0, 0]])

    # number of valid tokens that were not retained, per row
    num_dropped = tf.reduce_sum(source_mask, 1) - tf.reduce_sum(retained, 1)
    front_mask = dtype.tf_to_float(tf.greater(num_dropped, 0.))
    nothing_dropped = tf.less_equal(num_dropped, 0.)
    num_dropped = tf.where(nothing_dropped,
                           tf.ones_like(num_dropped), num_dropped)

    counts = tf.concat(
        [tf.expand_dims(num_dropped, 1), tf.ones_like(kept_mask)], 1)
    kept_mask = tf.concat([tf.expand_dims(front_mask, 1), kept_mask], 1)

    return kept_memory, kept_mask, counts
Exemplo n.º 5
0
def graph(features, params):
    """Assemble the model: optional BERT features, embedding, match, loss.

    When `params.enable_bert` is set, `features['ps']` and `features['hs']`
    are concatenated on axis 1 and encoded jointly; the encodings at the
    positions in `features['pb']` / `features['hb']` are stored (after
    dropout) as `features['bert_p_enc']` / `features['bert_h_enc']`, and a
    summary vector is stored as `features['feature']`.

    :param features: dict of input tensors, extended in place
    :param params: hyper-parameters (`enable_bert`, `use_bert_single`,
        `dropout` are read here)
    :return: the output of `loss_layer`
    """
    if params.enable_bert:
        p_pieces = features['ps']
        h_pieces = features['hs']

        joint_input = tf.concat([p_pieces, h_pieces], 1)
        bert_outputs = bert.bert_encoder(joint_input, params)
        joint_feature = bert.bert_feature(bert_outputs[0])

        split_at = tf.shape(p_pieces)[1]
        # 1: remove the encoding for `cls`
        p_part = joint_feature[:, 1:split_at, :]
        h_part = joint_feature[:, split_at:, :]

        p_positions = features['pb']
        h_positions = features['hb']

        p_pos_shape = util.shape_list(p_positions)
        h_pos_shape = util.shape_list(h_positions)

        p_rows = util.batch_coordinates(p_pos_shape[0], p_pos_shape[1])
        p_part = tf.gather_nd(p_part,
                              tf.stack([p_rows, p_positions], axis=2))

        h_rows = util.batch_coordinates(h_pos_shape[0], h_pos_shape[1])
        h_part = tf.gather_nd(h_part,
                              tf.stack([h_rows, h_positions], axis=2))

        features['bert_p_enc'] = util.valid_apply_dropout(p_part,
                                                          params.dropout)
        features['bert_h_enc'] = util.valid_apply_dropout(h_part,
                                                          params.dropout)
        if params.use_bert_single:
            features['feature'] = bert_outputs[1]
        else:
            features['feature'] = joint_feature[:, 0, :]

    features = embedding_layer(features, params)
    features = match_layer(features, params)
    return loss_layer(features, params)
Exemplo n.º 6
0
    def mlceloss(logits, labels):
        """Mean label-smoothed softmax cross-entropy.

        The smoothing factor comes from `params.label_smooth` in the
        enclosing scope. The normalizer returned by `util.label_smooth` is
        subtracted from every cross-entropy value before averaging.

        :param logits: unnormalized scores; the last axis is the class axis
        :param labels: labels passed to `util.label_smooth`
        :return: scalar mean loss
        """
        num_classes = util.shape_list(logits)[-1]
        smoothed, normalizer = util.label_smooth(labels,
                                                 num_classes,
                                                 factor=params.label_smooth)
        per_item = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=logits, labels=smoothed)
        per_item = tf.reshape(per_item - normalizer, tf.shape(labels))

        return tf.reduce_mean(per_item)
Exemplo n.º 7
0
def wrap_rnn(x,
             cell_type,
             nlayers,
             hidden_size,
             mask=None,
             bidir=True,
             use_ln=True,
             concat=True,
             dropout=0.0,
             scope=None):
    """Stack `nlayers` (optionally bidirectional) RNN layers on top of `x`.

    :param x: input sequence; axis 0 and 1 are used as batch and time
    :param cell_type: cell name forwarded to `rnn.rnn`
    :param nlayers: number of stacked layers
    :param hidden_size: hidden size of every direction
    :param mask: sequence mask; all ones when not given
    :param bidir: also run a backward RNN and concatenate both directions
    :param use_ln: forwarded to `rnn.rnn` as `ln`
    :param concat: return the concatenation of all layers instead of only
        the last layer
    :param dropout: dropout rate applied to outputs and states of every
        layer except the last one
    :param scope: prefix of the per-layer variable scopes ('rnn' by default)
    :return: `(outputs, states)`
    """
    layer_outputs = [x]
    layer_states = []

    if mask is None:
        in_shape = util.shape_list(x)
        mask = tf.ones([in_shape[0], in_shape[1]], tf.float32)

    for depth in range(nlayers):
        scope_name = "{}_layer_{}".format(scope or 'rnn', depth)
        with tf.variable_scope(scope_name):
            layer_input = layer_outputs[-1]

            with tf.variable_scope("fw_rnn"):
                _, forward = rnn.rnn(cell_type,
                                     layer_input,
                                     hidden_size,
                                     mask=mask,
                                     ln=use_ln,
                                     sm=False)
            directions = [forward]

            if bidir:
                with tf.variable_scope("bw_rnn"):
                    # run over the time-reversed sequence, then flip the
                    # outputs back to the original order
                    _, (bw_out, bw_state) = rnn.rnn(
                        cell_type,
                        tf.reverse(layer_input, [1]),
                        hidden_size,
                        mask=tf.reverse(mask, [1]),
                        ln=use_ln,
                        sm=False)
                    bw_out = tf.reverse(bw_out, [1])
                directions.append((bw_out, bw_state))

            if depth != nlayers - 1:
                # no dropout after the top layer
                directions = [
                    (util.valid_apply_dropout(out, dropout),
                     util.valid_apply_dropout(state, dropout))
                    for out, state in directions
                ]

            if bidir:
                layer_outputs.append(
                    tf.concat([out for out, _ in directions], -1))
                layer_states.append(
                    tf.concat([state for _, state in directions], -1))
            else:
                fw_out, fw_state = directions[0]
                layer_outputs.append(fw_out)
                layer_states.append(fw_state)

    if not concat:
        return layer_outputs[-1], layer_states[-1]
    return tf.concat(layer_outputs[1:], -1), tf.concat(layer_states, -1)
Exemplo n.º 8
0
def rnn(cell_name, x, d, mask=None, ln=False, init_state=None, sm=True, dp=0.0):
    """Self implemented RNN procedure, supporting mask trick.

    :param cell_name: gru, lstm or atr
    :param x: input sequence embedding matrix, [batch, seq_len, dim]
    :param d: hidden dimension for rnn
    :param mask: mask matrix, [batch, seq_len]; all ones when not given
    :param ln: whether use layer normalization
    :param init_state: the initial hidden states, for cache purpose
    :param sm: whether apply swap memory during rnn scan
    :param dp: variational dropout (not referenced inside this function)
    :return: `((outputs, final_state), (hidden_outputs, hidden_final))`
        where `outputs` is batch-major and the second pair is the first
        pair passed through `cell.get_hidden`
    """
    n_batch, n_steps = util.shape_list(x)[:2]

    cell = get_cell(cell_name, d, ln=ln)

    if init_state is None:
        init_state = cell.get_init_state(shape=[n_batch])
    if mask is None:
        mask = tf.ones([n_batch, n_steps], tf.float32)

    # tf.scan iterates over axis 0, so put time first everywhere
    time_major = [1, 0, 2]
    projected = [tf.transpose(item, time_major)
                 for item in list(cell.fetch_states(x))]
    step_mask = tf.transpose(tf.expand_dims(mask, -1), time_major)

    def _advance(carry, elems):
        step, prev_h = carry
        keep = elems[-1]
        cell_inputs = elems[:-1]

        new_h = cell(prev_h, cell_inputs)
        # masked positions keep the previous state unchanged
        blended = keep * new_h + (1. - keep) * prev_h

        return step + 1, blended

    start = tf.constant(0, dtype=tf.int32, name="time")

    scanned = tf.scan(_advance,
                      projected + [step_mask],
                      initializer=(start, init_state),
                      parallel_iterations=32,
                      swap_memory=sm)

    state_seq = scanned[1]
    final_state = state_seq[-1]

    # back to [batch, seq_len, dim]
    batch_major = tf.transpose(state_seq, [1, 0, 2])

    return (batch_major, final_state), \
           (cell.get_hidden(batch_major), cell.get_hidden(final_state))
Exemplo n.º 9
0
def trilinear_similarity(x1, x2, scope='trilinear'):
    """Trilinear similarity between every position of `x1` and of `x2`.

    The score for positions (i, j) is the sum of a linear term on
    `x1[i]`, a linear term on `x2[j]`, a weighted dot product of the two
    and a scalar bias.

    :param x1: tensor of shape (batch_size, len1, dimension)
    :param x2: tensor of shape (batch_size, len2, dimension)
    :param scope: variable scope name
    :return: similarity tensor of shape (batch_size, len1, len2)
    :raises ValueError: when an input is not 3-D or the last dimensions
        differ
    """
    with tf.variable_scope(scope or "trilinear"):
        shape1 = util.shape_list(x1)
        shape2 = util.shape_list(x2)

        if not (len(shape1) == 3 and len(shape2) == 3):
            raise ValueError(
                '`args` must be 3 dims (batch_size, len, dimension)')
        if shape1[2] != shape2[2]:
            raise ValueError('the last dimension of `args` must equal')

        kernel_x1 = tf.get_variable('kernel_x1', [shape1[2], 1])
        kernel_x2 = tf.get_variable('kernel_x2', [shape2[2], 1])
        kernel_mul = tf.get_variable('kernel_mul', [1, 1, shape1[2]])
        offset = tf.get_variable('bias', [1],
                                 initializer=tf.zeros_initializer())

        # (batch, len1, 1): contribution of x1 alone
        term_x1 = tf.einsum('aij,jk->aik', x1, kernel_x1)
        # (batch, 1, len2): contribution of x2 alone
        term_x2 = tf.einsum('aij,jk->aki', x2, kernel_x2)
        # (batch, len1, len2): element-wise weighted interaction
        term_mul = tf.einsum('aij,akj->aik', x1 * kernel_mul, x2)
        return term_x1 + term_x2 + term_mul + offset
Exemplo n.º 10
0
def rms_norm(x, eps=None, scope=None):
    """Root-mean-square layer normalization over the last axis.

    :param x: input tensor
    :param eps: stabilizer added to the mean square; `dtype.epsilon()`
        when not given
    :param scope: variable scope name, "rms_norm" when not given
    :return: `x` divided by its RMS and multiplied by a learned scale
    """
    epsilon = dtype.epsilon() if eps is None else eps
    float_type = tf.as_dtype(dtype.floatx())
    with tf.variable_scope(scope or "rms_norm", dtype=float_type):
        width = util.shape_list(x)[-1]

        gain = tf.get_variable("scale", [width],
                               initializer=tf.ones_initializer())

        mean_square = tf.reduce_mean(x ** 2, -1, keep_dims=True)
        inv_rms = tf.rsqrt(mean_square + epsilon)

        return gain * x * inv_rms
Exemplo n.º 11
0
def layer_norm(x, eps=1e-8, scope=None):
    """Layer normalization implemented as RMS normalization.

    https://openreview.net/pdf?id=SygkZ3MTJE

    :param x: input tensor, normalized over its last axis
    :param eps: stabilizer added to the mean square
    :param scope: variable scope name, "rms_norm" when not given
    :return: `x` divided by its RMS and multiplied by a learned scale
    """
    with tf.variable_scope(scope or "rms_norm"):
        width = util.shape_list(x)[-1]

        gain = tf.get_variable("scale", [width],
                               initializer=tf.ones_initializer())

        mean_square = tf.reduce_mean(x**2, -1, keep_dims=True)
        inv_rms = tf.rsqrt(mean_square + eps)

        return gain * x * inv_rms
Exemplo n.º 12
0
def layer_norm(x, eps=1e-8, scope=None):
    """Layer normalization over the last axis with learned scale and offset.

    :param x: input tensor
    :param eps: stabilizer added to the variance
    :param scope: variable scope name, "layer_norm" when not given
    :return: the normalized tensor
    """
    with tf.variable_scope(scope or "layer_norm"):
        width = util.shape_list(x)[-1]

        gain = tf.get_variable("scale", [width],
                               initializer=tf.ones_initializer())
        shift = tf.get_variable("offset", [width],
                                initializer=tf.zeros_initializer())

        mu = tf.reduce_mean(x, -1, keep_dims=True)
        variance = tf.reduce_mean((x - mu) ** 2, -1, keep_dims=True)
        inv_std = tf.rsqrt(variance + eps)

        return gain * (x - mu) * inv_std + shift
Exemplo n.º 13
0
def gated_rms_norm(x, eps=None, scope=None):
    """RMS layer normalization followed by a sigmoid gate.

    :param x: input tensor, normalized over its last axis
    :param eps: stabilizer added to the mean square; `dtype.epsilon()`
        when not given
    :param scope: variable scope name, "rms_norm" when not given
    :return: the RMS-normalized, scaled input multiplied by
        `sigmoid(gate * x)`
    """
    epsilon = dtype.epsilon() if eps is None else eps
    float_type = tf.as_dtype(dtype.floatx())
    with tf.variable_scope(scope or "rms_norm", dtype=float_type):
        width = util.shape_list(x)[-1]

        gain = tf.get_variable("scale", [width],
                               initializer=tf.ones_initializer())
        gate_weight = tf.get_variable("gate", [width], initializer=None)

        mean_square = tf.reduce_mean(x ** 2, -1, keep_dims=True)
        normalized = gain * x * tf.rsqrt(mean_square + epsilon)

        # adding gating here which slightly improves quality
        return normalized * tf.nn.sigmoid(gate_weight * x)
Exemplo n.º 14
0
def shard_features(features, num_devices):
    """Split features into several shards, one per device.

    The batch (axis 0 of the first feature) is divided with
    `util.uniform_splits`. When at least one device would receive an empty
    piece, every feature is tiled `num_devices` times along axis 0 first so
    that no shard is empty; `device_mask` still marks which devices
    received data from the untiled batch.

    :param features: a non-empty dictionary containing input datas
    :param num_devices: gpu device number
    :return: `(datashard_to_features, device_mask)` -- a list with one
        feature dictionary per device, and a float vector with 1.0 for
        every device whose piece of the untiled batch is non-empty
    :raises ValueError: when `features` is empty
    """
    if not features:
        raise ValueError("`features` must contain at least one entry")

    num_datashards = num_devices

    # `dict.values()` is not indexable and `iteritems` does not exist on
    # Python 3; the forms below behave the same on Python 2 and 3.
    first_feature = next(iter(features.values()))
    batch_size = tf.shape(first_feature)[0]

    sharded_features = {}
    pieces = util.uniform_splits(batch_size, num_datashards)
    device_mask = tf.to_float(tf.greater(pieces, 0))

    # why tile it?
    # because the piece can be 0-shaped.
    # feeding an empty input to the model may be problematic.
    tile_size = tf.cond(tf.reduce_any(tf.equal(device_mask, 0.0)),
                        lambda: num_datashards,
                        lambda: 1)
    tile_pieces = util.uniform_splits(batch_size * tile_size,
                                      num_datashards)

    for k, v in features.items():
        v = tf.convert_to_tensor(v)
        if not v.shape.as_list():
            # scalar feature: repeat it once per (tiled) example
            v = tf.expand_dims(v, axis=-1)
            v = tf.tile(v, [tf.reduce_sum(tile_pieces)])
        else:
            # to avoid the empty data input
            v_shp = util.shape_list(v)
            t_shp = [1] * len(v_shp)
            t_shp[0] = tile_size
            v = tf.tile(v, t_shp)
        # split on the device that already holds the feature
        with tf.device(v.device):
            sharded_features[k] = tf.split(v, tile_pieces, 0)

    datashard_to_features = [
        {k: v[d] for k, v in sharded_features.items()}
        for d in range(num_datashards)
    ]

    return datashard_to_features, device_mask
Exemplo n.º 15
0
def cnn(inputs, hidden_size, mask=None, scope="cnn"):
    """Width-3 gated convolution with residual connection and layer norm.

    Each position is concatenated with its left and right neighbour
    (zero at the sequence borders), projected to `2 * hidden_size`, split
    into a value half and a gate half, gated, added to the masked input
    and layer-normalized.

    :param inputs: input tensor; axis 0 and 1 are used as batch and time
    :param hidden_size: output size; the residual sum requires it to be
        compatible with the last dimension of `inputs`
    :param mask: optional mask over (batch, time); all ones when not given
    :param scope: variable scope name
    :return: the normalized output tensor
    """
    with tf.variable_scope(scope or "cnn"):
        in_shape = util.shape_list(inputs)
        if mask is None:
            mask = tf.ones([in_shape[0], in_shape[1]])

        masked = inputs * tf.expand_dims(mask, -1)

        keep, front, back = [0, 0], [1, 0], [0, 1]
        # shift right: position t sees t-1
        left = tf.pad(masked, [keep, front, keep])[:, :-1, :]
        # shift left: position t sees t+1
        right = tf.pad(masked, [keep, back, keep])[:, 1:, :]

        window = tf.concat([left, masked, right], -1)
        projected = linear(window, hidden_size * 2, ln=False, scope="ff")

        value = projected[:, :, :hidden_size]
        gate = projected[:, :, hidden_size:]
        gated = value * tf.sigmoid(gate)

        return layer_norm(gated + masked, scope="ln")
Exemplo n.º 16
0
def embedding_layer(features, params):
    """Build masked input embeddings for the `p` and `h` sequences.

    Word embeddings are always used; BERT encodings
    (`features['bert_p_enc']` / `features['bert_h_enc']`) are appended when
    `params.enable_bert` is set and max-pooled character embeddings (from
    `features['pc']` / `features['hc']`) when `params.use_char` is set.
    All parts are concatenated on axis 2 and multiplied by the sequence
    mask.

    :param features: dict of input tensors; updated in place with
        `p_emb`, `h_emb`, `p_mask` and `h_mask`
    :param params: hyper-parameters; `embed_size`,
        `pretrain_word_embedding_file`, `word_vocab`, `enable_bert`,
        `use_char`, `char_vocab`, `char_embed_size` and `dropout` are read
    :return: the updated `features` dict
    """
    p = features['p']
    h = features['h']

    # any non-zero id counts as a valid position
    p_mask = tf.to_float(tf.cast(p, tf.bool))
    h_mask = tf.to_float(tf.cast(h, tf.bool))

    with tf.device('/cpu:0'):
        # the first 3 rows of the table are always trainable
        symbol_embeddings = tf.get_variable('special_symbol_embeddings',
                                            shape=(3, params.embed_size),
                                            trainable=True)
        # the remaining rows are frozen; they are initialized from the
        # pretrained file when it exists, otherwise with Glorot uniform
        embedding_initializer = tf.glorot_uniform_initializer()
        if tf.gfile.Exists(params.pretrain_word_embedding_file):
            pretrain_embedding = np.load(params.pretrain_word_embedding_file)['data']
            embedding_initializer = tf.constant_initializer(pretrain_embedding)
        general_embeddings = tf.get_variable('general_symbol_embeddings',
                                             shape=(params.word_vocab.size() - 3, params.embed_size),
                                             initializer=embedding_initializer,
                                             trainable=False)
        word_embeddings = tf.concat([symbol_embeddings, general_embeddings], 0)

        p_emb = tf.nn.embedding_lookup(word_embeddings, p)
        h_emb = tf.nn.embedding_lookup(word_embeddings, h)

    p_features = [p_emb]
    h_features = [h_emb]

    if params.enable_bert:
        p_features.append(features['bert_p_enc'])
        h_features.append(features['bert_h_enc'])

    if params.use_char:
        pc = features['pc']
        hc = features['hc']

        pc_mask = tf.to_float(tf.cast(pc, tf.bool))
        hc_mask = tf.to_float(tf.cast(hc, tf.bool))

        # flatten all leading axes; only the last (character) axis is kept
        pc = tf.reshape(pc, [-1, tf.shape(pc)[-1]])
        hc = tf.reshape(hc, [-1, tf.shape(hc)[-1]])
        pc_mask = tf.reshape(pc_mask, [-1, tf.shape(pc_mask)[-1]])
        hc_mask = tf.reshape(hc_mask, [-1, tf.shape(hc_mask)[-1]])
        with tf.device('/cpu:0'):
            char_embeddings = tf.get_variable('char_embeddings',
                                              shape=(params.char_vocab.size(), params.char_embed_size),
                                              initializer=tf.glorot_uniform_initializer(),
                                              trainable=True)
            with tf.variable_scope('char_embedding'):
                pc_emb = tf.nn.embedding_lookup(char_embeddings, pc)
                hc_emb = tf.nn.embedding_lookup(char_embeddings, hc)
                # character embeddings use half of the configured dropout
                if util.valid_dropout(params.dropout):
                    pc_emb = tf.nn.dropout(pc_emb, 1. - 0.5 * params.dropout)
                    hc_emb = tf.nn.dropout(hc_emb, 1. - 0.5 * params.dropout)

        # AUTO_REUSE lets both `cmap` projections below share one set of weights
        with tf.variable_scope("char_encoding", reuse=tf.AUTO_REUSE):
            pc_emb = pc_emb * tf.expand_dims(pc_mask, -1)
            hc_emb = hc_emb * tf.expand_dims(hc_mask, -1)

            # restore the first three axes of the original character ids
            pc_shp = util.shape_list(features['pc'])
            pc_emb = tf.reshape(pc_emb, [pc_shp[0], pc_shp[1], pc_shp[2], params.char_embed_size])
            hc_shp = util.shape_list(features['hc'])
            hc_emb = tf.reshape(hc_emb, [hc_shp[0], hc_shp[1], hc_shp[2], params.char_embed_size])

            # max-pool over axis 2 (the character axis), then project
            pc_state = func.linear(tf.reduce_max(pc_emb, 2), params.char_embed_size, scope="cmap")
            hc_state = func.linear(tf.reduce_max(hc_emb, 2), params.char_embed_size, scope="cmap")

        p_features.append(pc_state)
        h_features.append(hc_state)

    # NOTE(review): the string literal below is disabled highway-layer code
    # kept as an inert expression statement; it has no effect at runtime.
    '''
    p_emb = func.highway(tf.concat(p_features, axis=2),
                         size=params.hidden_size, dropout=params.dropout, num_layers=2, scope='highway')
    h_emb = func.highway(tf.concat(h_features, axis=2),
                         size=params.hidden_size, dropout=params.dropout, num_layers=2, scope='highway')
    '''
    p_emb = tf.concat(p_features, axis=2)
    h_emb = tf.concat(h_features, axis=2)

    # zero out padded positions
    p_emb = p_emb * tf.expand_dims(p_mask, -1)
    h_emb = h_emb * tf.expand_dims(h_mask, -1)

    features.update({'p_emb': p_emb,
                     'h_emb': h_emb,
                     'p_mask': p_mask,
                     'h_mask': h_mask,
                     })
    return features
Exemplo n.º 17
0
def decoder(target, state, params):
    """Run the conditional RNN decoder on `target` and compute its loss.

    Two modes are selected by whether `state` has a 'decoder' entry.
    Without it, the whole target sequence is processed with the inputs
    shifted right by one step (presumably training/scoring -- confirm
    against callers). With it, the RNN starts from
    `state['decoder']['state']` and the new hidden states are written back
    (presumably step-wise inference -- confirm against callers).

    :param target: target token ids; id 0 is treated as masked
    :param state: dict with `decoder_initializer`, `encodes`, `mask` and
        optionally `decoder`; may be updated in place
    :param params: hyper-parameters (vocabulary, sizes, dropout, sharing
        and cell options are read)
    :return: `(loss, logits, state)`
    """
    mask = tf.to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size

    if 'decoder' not in state:
        target, mask = util.remove_invalid_seq(target, mask)

    # with a shared source/target embedding the variable is named "embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if 'decoder' not in state:
        # prepend a zero step and drop the last one, so step t is fed the
        # embedding of token t-1
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
    else:
        # an all-pad target is fed as zeros
        inputs = tf.cond(tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
                         lambda: tf.zeros_like(inputs),
                         lambda: inputs)
        mask = tf.ones_like(mask)

    if util.valid_dropout(params.dropout):
        inputs = tf.nn.dropout(inputs, 1. - params.dropout)

    with tf.variable_scope("decoder"):
        init_state = state["decoder_initializer"]
        if 'decoder' in state:
            init_state = state["decoder"]["state"]
        returns = rnn.cond_rnn(params.cell, inputs, state["encodes"], hidden_size,
                               init_state=init_state, mask=mask,
                               mem_mask=state["mask"], ln=params.layer_norm,
                               sm=params.swap_memory, one2one=False)
        (hidden_states, _), (outputs, _), contexts, attentions = returns

    # combine RNN outputs, attention contexts and input embeddings
    feature = linear([outputs, contexts, inputs], params.embed_size,
                     ln=params.layer_norm, scope="pre_logits")
    feature = tf.tanh(feature)
    if util.valid_dropout(params.dropout):
        feature = tf.nn.dropout(feature, 1. - params.dropout)

    # choose the softmax matrix name; the shared source/target setting
    # takes precedence over the shared target/softmax setting
    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    # logits = feature x softmax_emb^T (second operand transposed)
    logits = tf.matmul(feature, softmax_emb, False, True)

    centropy = tf.nn.softmax_cross_entropy_with_logits(
        logits=logits,
        labels=util.label_smooth(target,
                                 util.shape_list(logits)[-1],
                                 factor=params.label_smooth)
    )
    centropy = tf.reshape(centropy, tf.shape(target))

    # masked mean over the last axis of each sequence, then mean over sequences
    loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32),
                   lambda: loss)

    if 'decoder' in state:
        # expose the new hidden states to the next call
        state['decoder']['state'] = hidden_states

    return loss, logits, state
Exemplo n.º 18
0
def embedding_layer(features, params):
    """Build masked input embeddings for the `t` sequence.

    Word embeddings (with word dropout) are always used; the BERT encoding
    `features['bert_enc']` is appended when `params.enable_bert` is set and
    max-pooled character embeddings (from `features['c']`) when
    `params.use_char` is set. All parts are concatenated on axis 2 and
    multiplied by the sequence mask.

    :param features: dict of input tensors; updated in place with `t_emb`
        and `t_mask`
    :param params: hyper-parameters; `embed_size`, `word_vocab`,
        `word_dropout`, `enable_bert`, `use_char`, `char_vocab`,
        `char_embed_size` and `dropout` are read
    :return: the updated `features` dict
    """
    token_ids = features['t']

    # any non-zero id counts as a valid position
    token_mask = tf.to_float(tf.cast(token_ids, tf.bool))

    with tf.device('/cpu:0'):
        special_table = tf.get_variable('special_symbol_embeddings',
                                        shape=(3, params.embed_size),
                                        trainable=True)

        # pretrained vectors are loaded as constants and kept frozen;
        # otherwise the table is Glorot-initialized and trained
        pretrained = params.word_vocab.pretrained_embedding
        has_pretrained = pretrained is not None
        if has_pretrained:
            general_init = tf.constant_initializer(pretrained)
        else:
            general_init = tf.glorot_uniform_initializer()
        general_table = tf.get_variable(
            'general_symbol_embeddings',
            shape=(params.word_vocab.size() - 3, params.embed_size),
            initializer=general_init,
            trainable=not has_pretrained)
        full_table = tf.concat([special_table, general_table], 0)

        # apply word dropout: dropped positions look up id 0
        keep_mask = util.valid_apply_dropout(token_mask, params.word_dropout)
        keep_mask = tf.to_float(tf.cast(keep_mask, tf.bool))
        kept_ids = token_ids * tf.to_int32(keep_mask)

        word_emb = tf.nn.embedding_lookup(full_table, kept_ids)
        word_emb = word_emb * tf.expand_dims(token_mask, -1)

    parts = [word_emb]

    if params.enable_bert:
        parts.append(features['bert_enc'])

    if params.use_char:
        char_ids = features['c']
        char_mask = tf.to_float(tf.cast(char_ids, tf.bool))

        # flatten all leading axes; only the last (character) axis is kept
        flat_ids = tf.reshape(char_ids, [-1, tf.shape(char_ids)[-1]])
        flat_mask = tf.reshape(char_mask, [-1, tf.shape(char_mask)[-1]])

        with tf.device('/cpu:0'):
            char_table = tf.get_variable(
                'char_embeddings',
                shape=(params.char_vocab.size(), params.char_embed_size),
                initializer=tf.glorot_uniform_initializer(),
                trainable=True)
            with tf.variable_scope('char_embedding'):
                char_emb = tf.nn.embedding_lookup(char_table, flat_ids)
                # character embeddings use half of the configured dropout
                char_emb = util.valid_apply_dropout(char_emb,
                                                    0.5 * params.dropout)

        with tf.variable_scope("char_encoding", reuse=tf.AUTO_REUSE):
            char_emb = char_emb * tf.expand_dims(flat_mask, -1)

            # restore the first three axes of the original character ids
            ids_shape = util.shape_list(char_ids)
            unflat_shape = [ids_shape[0], ids_shape[1], ids_shape[2],
                            params.char_embed_size]
            char_emb = tf.reshape(char_emb, unflat_shape)

            # max-pool over axis 2 (the character axis), then project
            pooled = tf.reduce_max(char_emb, 2)
            char_state = func.linear(pooled,
                                     params.char_embed_size,
                                     scope="cmap")

        parts.append(char_state)

    combined = tf.concat(parts, axis=2) * tf.expand_dims(token_mask, -1)

    features.update({
        't_emb': combined,
        't_mask': token_mask,
    })
    return features
Exemplo n.º 19
0
    def _step_fn(time, bsstate):
        """one expansion step of beam search process

        Relies on names from the enclosing scope: `decoding_fn`, `eos_id`,
        `pad_id`, `batch_size`, `beam_size`, `alpha`, `max_target_length`,
        `BeamSearchState` and `nest`.

        :param time: current decoding step (integer tensor)
        :param bsstate: state with `inputs` (alive sequences, their
            log-probabilities and scores), `state` (decoder state) and
            `finish` (finished sequences, scores and flags)
        :return: `(time + 1, next_state)`
        """

        # 1. feed previous predictions, and get the next probabilities
        # generating beam * vocab_size predictions
        prev_seq, prev_log_probs, prev_scores = bsstate.inputs

        # fold the beam axis into the batch axis for the decoder call
        flat_prev_seqs = util.merge_neighbor_dims(prev_seq, axis=0)
        flat_prev_state = nest.map_structure(
            lambda x: util.merge_neighbor_dims(x, axis=0), bsstate.state)

        # curr_logits: [batch * beam, vocab_size]
        # only the most recent token of every sequence is fed
        step_logits, step_state = decoding_fn(flat_prev_seqs[:, -1:],
                                              flat_prev_state, time)
        step_log_probs = util.log_prob_from_logits(step_logits)
        vocab_size = util.shape_list(step_log_probs)[-1]

        # force decoding
        # while time < 1, add -1e9 to the eos entry so that eos is not chosen
        eos_mask = tf.to_float(tf.equal(tf.range(vocab_size), eos_id))
        step_log_probs = tf.cond(
            tf.to_float(time) < tf.to_float(1.),
            lambda: step_log_probs + tf.expand_dims(eos_mask, 0) * -1e9,
            lambda: step_log_probs)

        # expand to [batch, beam, vocab_size]
        step_log_probs = util.unmerge_neighbor_dims(step_log_probs,
                                                    batch_size,
                                                    axis=0)
        step_state = nest.map_structure(
            lambda x: util.unmerge_neighbor_dims(x, batch_size, axis=0),
            step_state)

        # 2. compute top-k scored next predictions
        # reducing beam * vocab_size to 2 * beam
        # [batch, beam, 1] + [batch, beam, vocab_size]
        curr_log_probs = tf.expand_dims(prev_log_probs, 2) + step_log_probs
        # length penalty: ((5 + length) / 6) ** alpha with length = time + 1
        length_penality = tf.pow((5.0 + tf.to_float(time + 1)) / 6., alpha)
        curr_scores = curr_log_probs / length_penality

        # [batch, beam * vocab_size]
        curr_flat_scores = util.merge_neighbor_dims(curr_scores, axis=1)
        # [batch, 2 * beam]
        topk_scores, topk_indices = tf.nn.top_k(curr_flat_scores,
                                                2 * beam_size)

        # index manipulation, [batch, 2 * beam]
        # a flat index encodes (beam, symbol) as beam * vocab_size + symbol
        curr_beam_indices = topk_indices // vocab_size
        curr_symbol_indices = topk_indices % vocab_size
        beam2_pos = util.batch_coordinates(batch_size, 2 * beam_size)
        curr_coordinates = tf.stack([beam2_pos, curr_beam_indices], axis=2)

        # extract candidate sequences
        # [batch, 2 * beam, time + 1]
        curr_seq = tf.gather_nd(prev_seq, curr_coordinates)
        curr_seq = tf.concat(
            [curr_seq, tf.expand_dims(curr_symbol_indices, 2)], 2)

        # 3. handling alive sequences
        # reducing 2 * beam to beam
        curr_fin_flags = tf.logical_or(
            tf.equal(curr_symbol_indices, eos_id),
            # if time step exceeds the maximum decoding length, should stop
            tf.expand_dims(
                tf.greater_equal(time, tf.to_int32(max_target_length)), 1))
        # finished candidates are pushed to the bottom of the alive ranking
        alive_scores = topk_scores + \
                       tf.to_float(curr_fin_flags) * tf.float32.min
        # [batch, 2 * beam] -> [batch, beam]
        alive_scores, alive_indices = tf.nn.top_k(alive_scores, beam_size)
        beam_pos = util.batch_coordinates(batch_size, beam_size)
        alive_coordinates = tf.stack([beam_pos, alive_indices], axis=2)
        alive_seq = tf.gather_nd(curr_seq, alive_coordinates)
        # map every survivor back to the beam it was expanded from, so the
        # matching decoder state can be selected
        alive_beam_indices = tf.gather_nd(curr_beam_indices, alive_coordinates)
        beam_coordinates = tf.stack([beam_pos, alive_beam_indices], axis=2)
        alive_state = nest.map_structure(
            lambda x: tf.gather_nd(x, beam_coordinates), step_state)
        # undo the length normalization to recover raw log-probabilities
        alive_log_probs = alive_scores * length_penality

        # 4. handle finished sequences
        # reducing 3 * beam to beam
        prev_fin_seq, prev_fin_scores, prev_fin_flags = bsstate.finish
        # [batch, 2 * beam]
        # unfinished candidates are pushed to the bottom of the finished ranking
        curr_fin_scores = topk_scores + \
                          (1.0 - tf.to_float(curr_fin_flags)) * tf.float32.min
        # [batch, 3 * beam]
        fin_flags = tf.concat([prev_fin_flags, curr_fin_flags], axis=1)
        fin_scores = tf.concat([prev_fin_scores, curr_fin_scores], axis=1)
        # [batch, beam]
        fin_scores, fin_indices = tf.nn.top_k(fin_scores, beam_size)
        fin_coordinates = tf.stack([beam_pos, fin_indices], axis=2)
        fin_flags = tf.gather_nd(fin_flags, fin_coordinates)
        # pad previously finished sequences by one token so their length
        # matches the new candidates before concatenating on the beam axis
        pad_seq = tf.fill([batch_size, beam_size, 1],
                          tf.constant(pad_id, tf.int32))
        prev_fin_seq = tf.concat([prev_fin_seq, pad_seq], axis=2)
        fin_seq = tf.concat([prev_fin_seq, curr_seq], axis=1)
        fin_seq = tf.gather_nd(fin_seq, fin_coordinates)

        next_state = BeamSearchState(inputs=(alive_seq, alive_log_probs,
                                             alive_scores),
                                     state=alive_state,
                                     finish=(fin_seq, fin_scores, fin_flags))

        return time + 1, next_state
Exemplo n.º 20
0
def decoder(target, state, params):
    """
    Multi-layer attentional RNN decoder that builds logits and the loss.

    :param target: target token ids; the padding mask is derived from the
        non-zero entries (``tf.cast(target, tf.bool)``)
    :param state: dict holding "encodes", "mask" and "decoder_initializer";
        when it contains "decoder" the function runs in cached decoding mode
        and writes the new hidden states back into
        ``state['decoder']['state']``; when it contains "dev_decode" only the
        last time step of the output feature is scored
    :param params: hyper-parameter object (hidden_size, embed_size, cell,
        num_decoder_layer, dropout, layer_norm, ...)
    :return: a tuple ``(loss, logits, state)``
    """
    mask = tf.to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size

    if 'decoder' not in state:
        target, mask = util.remove_invalid_seq(target, mask)

    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if 'decoder' not in state:
        # training: prepend a zero step and drop the last one so that the
        # input at position t is the embedding of the token at t - 1
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
    else:
        # decoding: an all-padding target is replaced by a zero input
        inputs = tf.cond(
            tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
            lambda: tf.zeros_like(inputs), lambda: inputs)
        mask = tf.ones_like(mask)

    if util.valid_dropout(params.dropout):
        inputs = tf.nn.dropout(inputs, 1. - params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            with tf.variable_scope("layer_{}".format(layer)):
                init_state = state["decoder_initializer"]["layer_{}".format(
                    layer)]
                if 'decoder' in state:
                    init_state = state["decoder"]["state"]["layer_{}".format(
                        layer)]
                if layer == 0 or params.use_deep_att:
                    # attend over the source encodings
                    returns = rnn.cond_rnn(params.cell,
                                           x,
                                           state["encodes"],
                                           hidden_size,
                                           init_state=init_state,
                                           mask=mask,
                                           num_heads=params.num_heads,
                                           mem_mask=state["mask"],
                                           ln=params.layer_norm,
                                           sm=params.swap_memory,
                                           one2one=False,
                                           dp=params.dropout)
                    (_, hidden_state), (outputs,
                                        _), contexts, attentions = returns
                    # `c` is reused by the higher layers and by the final
                    # feature projection below
                    c = contexts
                else:
                    if params.caencoder:
                        # condition on the context `c` step by step
                        returns = rnn.cond_rnn(params.cell,
                                               x,
                                               c,
                                               hidden_size,
                                               init_state=init_state,
                                               mask=mask,
                                               mem_mask=mask,
                                               ln=params.layer_norm,
                                               sm=params.swap_memory,
                                               num_heads=params.num_heads,
                                               one2one=True,
                                               dp=params.dropout)
                        (_, hidden_state), (outputs,
                                            _), contexts, attentions = returns
                    else:
                        # plain RNN over the layer input joined with `c`
                        outputs = rnn.rnn(params.cell,
                                          tf.concat([x, c], -1),
                                          hidden_size,
                                          mask=mask,
                                          init_state=init_state,
                                          ln=params.layer_norm,
                                          sm=params.swap_memory,
                                          dp=params.dropout)
                        outputs, hidden_state = outputs[1]
                if 'decoder' in state:
                    # cache the hidden state for the next decoding step
                    state['decoder']['state']['layer_{}'.format(
                        layer)] = hidden_state

                y = func.linear(outputs, hidden_size, ln=False, scope="ff")

                # short cut via residual connection
                if x.get_shape()[-1].value == y.get_shape()[-1].value:
                    x = func.residual_fn(x, y, dropout=params.dropout)
                else:
                    x = y
                if params.layer_norm:
                    x = func.layer_norm(x, scope="ln")

    feature = func.linear(tf.concat([x, c], -1),
                          params.embed_size,
                          ln=params.layer_norm,
                          scope="ff")
    feature = tf.nn.tanh(feature)

    if util.valid_dropout(params.dropout):
        feature = tf.nn.dropout(feature, 1. - params.dropout)

    if 'dev_decode' in state:
        # keep only the last time step of the *projected* feature; slicing
        # the raw layer output `x` here would bypass the "ff" projection
        # above and feed a hidden_size-wide tensor to the softmax
        feature = feature[:, -1, :]

    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    soft_label, normalizer = util.label_smooth(target,
                                               util.shape_list(logits)[-1],
                                               factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                          labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    # per-sequence average over unmasked positions, then batch mean
    loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32), lambda: loss)

    return loss, logits, state
Exemplo n.º 21
0
def dot_attention(query, memory, mem_mask, hidden_size,
                  ln=False, num_heads=1, cache=None, dropout=None,
                  use_relative_pos=False, max_relative_position=16,
                  out_map=True, scope=None, fuse_mask=None,
                  decode_step=None):
    """
    Multi-head dot-product attention, usable as self- or cross-attention.

    :param query: [batch_size, query_len, dim]
    :param memory: [batch_size, seq_len, mem_dim], or None for self-attention
    :param mem_mask: added to the attention logits before the softmax;
        skipped when None
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization in the linear maps
    :param num_heads: attention head number
    :param cache: dict used for cache-based decoding; 'k'/'v' for
        self-attention, 'mk'/'mv' (and 'aan') for cross-attention
    :param dropout: attention dropout, default disable
    :param use_relative_pos: whether use relative position information
    :param max_relative_position: maximum position considered for relative
        embedding
    :param out_map: whether apply the additional output mapping
    :param scope: variable scope name, "dot_attention" by default
    :param fuse_mask: aan mask during training, and timestep for testing
    :param decode_step: the time step of current decoding, 0-based
    :return: a dict with the attention 'weights', the 'output' tensor and the
        (possibly updated) 'cache'
    """
    with tf.variable_scope(scope or "dot_attention", reuse=tf.AUTO_REUSE,
                           dtype=tf.as_dtype(dtype.floatx())):
        if fuse_mask is not None:
            assert memory is not None, 'Fuse mechanism only applied with cross-attention'
        if cache and use_relative_pos:
            assert decode_step is not None, 'Decode Step must provide when use relative position encoding'

        if memory is not None:
            # cross-attention: keys/values come from memory (or its cache)
            q = linear(query, hidden_size, ln=ln, scope="q_map")
            memory_cached = cache is not None and \
                'mk' in cache and 'mv' in cache
            if memory_cached:
                k = cache['mk']
                v = cache['mv']
            else:
                k = linear(memory, hidden_size, ln=ln, scope="k_map")
                v = linear(memory, hidden_size, ln=ln, scope="v_map")

            if cache is not None:
                cache['mk'] = k
                cache['mv'] = v
        else:
            # self-attention: a single fused projection gives q, k and v
            qkv = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(qkv, 3, -1)

            if cache is not None:
                # extend the cached keys/values with the current step
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = dict(k=k, v=v)

        q = split_heads(q, num_heads)
        k = split_heads(k, num_heads)
        v = split_heads(v, num_heads)

        # scale queries by 1 / sqrt(per-head dimension)
        head_dim = hidden_size // num_heads
        q = q * head_dim ** (-0.5)

        q_dims = util.shape_list(q)
        k_dims = util.shape_list(k)
        v_dims = util.shape_list(v)

        if decode_step is None:
            q_len, last_only = q_dims[2], None
        else:
            q_len, last_only = decode_step + 1, 1

        # q * k => attention weights
        if not use_relative_pos:
            logits = tf.matmul(q, k, transpose_b=True)
        else:
            key_rel = rpr.get_relative_positions_embeddings(
                q_len, k_dims[2], k_dims[3],
                max_relative_position, name="rpr_keys", last=last_only)
            logits = rpr.relative_attention_inner(
                q, k, key_rel, transpose=True)

        if mem_mask is not None:
            logits += mem_mask

        weights = tf.nn.softmax(logits)

        dropped_weights = util.valid_apply_dropout(weights, dropout)

        # weights * v => attention vectors
        if not use_relative_pos:
            o = tf.matmul(dropped_weights, v)
        else:
            value_rel = rpr.get_relative_positions_embeddings(
                q_len, k_dims[2], v_dims[3],
                max_relative_position, name="rpr_values", last=last_only)
            o = rpr.relative_attention_inner(
                dropped_weights, v, value_rel, transpose=False)

        o = combine_heads(o)

        if fuse_mask is not None:
            # This is for AAN, the important part is sharing v_map
            v_q = linear(query, hidden_size, ln=ln, scope="v_map")

            if cache is None:
                # Simplified Average Attention Network
                aan_o = tf.matmul(fuse_mask, v_q)
            elif 'aan' in cache:
                # decoding: average the running sum kept in the cache
                aan_o = (v_q + cache['aan']) / dtype.tf_to_float(fuse_mask + 1)
                cache['aan'] = v_q + cache['aan']
            else:
                aan_o = tf.matmul(fuse_mask, v_q)
                cache['aan'] = v_q

            # Directly sum both self-attention and cross attention
            o = o + aan_o

        if out_map:
            o = linear(o, hidden_size, ln=ln, scope="o_map")

        return {
            'weights': weights,
            'output': o,
            'cache': cache,
        }
Exemplo n.º 22
0
def bert_encoder(sequence, params):
    """
    Encode a token-id sequence with a BERT-style transformer stack.

    :param sequence: token id tensor; its first two dimensions are read as
        batch size and sequence length
    :param params: hyper-parameter object; reads ``params.bert`` (vocab and
        model sizes), ``params.tune_bert`` (whether variables are trainable)
        and ``params.use_bert_single`` (whether to add the pooled output)
    :return: a list whose first item is the list of all encoder layer
        outputs, followed by the pooled first-token output when
        ``params.use_bert_single`` is set
    """
    bert = params.bert

    # 1 for real tokens, 0 for padding
    seq_mask = 1. - tf.to_float(tf.equal(sequence, bert.vocab.pad))

    # segment ids: number of separators seen strictly before each position
    sep_hits = tf.to_float(tf.equal(sequence, bert.vocab.sep))
    seps_remaining = tf.cumsum(sep_hits, axis=1, reverse=True)
    sep_total = tf.reduce_sum(sep_hits, axis=1, keepdims=True)
    seg_ids = tf.to_int32((sep_total - seps_remaining) * seq_mask)

    # sequence length information
    _, seq_length = util.shape_list(sequence)[:2]

    def _trainable_getter(getter, name, *args, **kwargs):
        # every variable under "bert" follows the tune_bert switch
        kwargs['trainable'] = params.tune_bert
        return getter(name, *args, **kwargs)

    with tf.variable_scope("bert", custom_getter=_trainable_getter):

        # token embedding plus segment embedding plus positional embedding
        embed_initializer = tf.truncated_normal_initializer(
            stddev=bert.initializer_range)
        with tf.variable_scope("embeddings"):
            word_table = tf.get_variable(
                name="word_embeddings",
                shape=[bert.vocab.size, bert.hidden_size],
                initializer=embed_initializer)
            seq_embed = tf.nn.embedding_lookup(word_table, sequence)

            segment_table = tf.get_variable(
                name="token_type_embeddings",
                shape=[2, bert.hidden_size],
                initializer=embed_initializer)
            seg_embed = tf.nn.embedding_lookup(segment_table, seg_ids)

            # word embedding + segment embedding
            seq_embed = seq_embed + seg_embed

            # add position embedding, guarded by a length check
            length_check = tf.assert_less_equal(
                seq_length, bert.max_position_embeddings)
            with tf.control_dependencies([length_check]):
                position_table = tf.get_variable(
                    name="position_embeddings",
                    shape=[bert.max_position_embeddings, bert.hidden_size],
                    initializer=embed_initializer)
                pos_embed = position_table[:seq_length]

                seq_embed = seq_embed + tf.expand_dims(pos_embed, 0)

            # post-processing: layer norm followed by dropout
            seq_embed = tc.layers.layer_norm(
                inputs=seq_embed, begin_norm_axis=-1, begin_params_axis=-1)

            seq_embed = util.valid_apply_dropout(
                seq_embed, bert.hidden_dropout_prob)

        # sequence encoding with the stacked transformer encoder
        with tf.variable_scope("encoder"):
            attention_mask = encoder.create_attention_mask_from_input_mask(
                sequence, seq_mask)

            # with do_return_all_layers=True every layer output is returned
            all_encoder_layers = encoder.transformer_model(
                input_tensor=seq_embed,
                attention_mask=attention_mask,
                hidden_size=bert.hidden_size,
                num_hidden_layers=bert.num_hidden_layers,
                num_attention_heads=bert.num_attention_heads,
                intermediate_size=bert.intermediate_size,
                intermediate_act_fn=encoder.get_activation(bert.hidden_act),
                hidden_dropout_prob=bert.hidden_dropout_prob,
                attention_probs_dropout_prob=bert.attention_probs_dropout_prob,
                initializer_range=bert.initializer_range,
                do_return_all_layers=True)

        bert_outputs = [all_encoder_layers]

        if params.use_bert_single:
            # The "pooler" maps the last layer's first-token state through a
            # tanh dense layer to get one fixed-size vector per sequence.
            with tf.variable_scope("pooler"):
                last_layer = all_encoder_layers[-1]
                first_token_tensor = tf.squeeze(last_layer[:, 0:1, :], axis=1)
                pooled_output = tf.layers.dense(
                    first_token_tensor,
                    bert.hidden_size,
                    activation=tf.tanh,
                    kernel_initializer=embed_initializer)

                bert_outputs.append(pooled_output)

        return bert_outputs
Exemplo n.º 23
0
def deep_att_dec_rnn(cell_name,
                     x,
                     memory,
                     d,
                     init_state=None,
                     mask=None,
                     mem_mask=None,
                     ln=False,
                     sm=True,
                     depth=1,
                     num_heads=1):
    """
    Self implemented conditional-RNN procedure with deep attention,
    supporting mask trick.

    One lower cell consumes the input embedding at each step; it is followed
    by `depth` rounds of (additive attention over memory -> higher cell).

    :param cell_name: gru, lstm or atr
    :param x: input sequence embedding matrix, [batch, seq_len, dim]
    :param memory: the conditional part that is attended to
    :param d: hidden dimension for rnn
    :param init_state: the initial hidden states, for cache purpose; defaults
        to the lower cell's initial state
    :param mask: mask matrix, [batch, seq_len]; defaults to all ones
    :param mem_mask: memory mask matrix, [batch, mem_seq_len]; defaults to
        all ones
    :param ln: whether use layer normalization
    :param sm: whether apply swap memory during rnn scan
    :param depth: depth for the decoder in deep attention
    :param num_heads: number of attention heads, multi-head attention
    :return: a tuple of
        (outputs, last-step states),
        (hidden view of outputs, hidden view of last-step states) as given
        by the last higher cell's `get_hidden`,
        contexts and attentions collected over all steps
    """

    in_shape = util.shape_list(x)
    batch_size, time_steps = in_shape[:2]
    mem_shape = util.shape_list(memory)

    # one cell for the word-level recurrence, `depth` cells for the
    # attention-conditioned recurrences stacked on top of it
    cell_lower = rnn.get_cell(cell_name,
                              d,
                              ln=ln,
                              scope="{}_lower".format(cell_name))
    cells_higher = []
    for layer in range(depth):
        cell_higher = rnn.get_cell(cell_name,
                                   d,
                                   ln=ln,
                                   scope="{}_higher_{}".format(
                                       cell_name, layer))
        cells_higher.append(cell_higher)

    if init_state is None:
        init_state = cell_lower.get_init_state(shape=[batch_size])
    if mask is None:
        mask = dtype.tf_to_float(tf.ones([batch_size, time_steps]))
    if mem_mask is None:
        mem_mask = dtype.tf_to_float(tf.ones([batch_size, mem_shape[1]]))

    # prepare projected encodes and inputs
    cache_inputs = cell_lower.fetch_states(x)
    # switch to time-major layout so that tf.scan iterates over time steps
    cache_inputs = [tf.transpose(v, [1, 0, 2]) for v in list(cache_inputs)]
    # memory projection computed once and handed to every attention call
    proj_memories = func.linear(memory,
                                mem_shape[-1],
                                bias=False,
                                ln=ln,
                                scope="context_att")

    mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2])
    # zero placeholders that fix the structure of the scan carry
    init_context = dtype.tf_to_float(
        tf.zeros([batch_size, depth, mem_shape[-1]]))
    init_weight = dtype.tf_to_float(
        tf.zeros([batch_size, depth, num_heads, mem_shape[1]]))
    # the mask is appended after the cell inputs, so this is its index
    mask_pos = len(cache_inputs)

    def _step_fn(prev, x):
        # carry: step counter, hidden state, contexts, attention weights;
        # c_ and a_ (previous contexts/weights) are not read in this step
        t, h_, c_, a_ = prev

        # split the scanned elements into the step mask and the cell inputs
        m, v = x[mask_pos], x[:mask_pos]

        # the first decoder rnn subcell, composing previous hidden state with the current word embedding
        s_ = cell_lower(h_, v)
        # masked positions keep the previous state unchanged
        s_ = m * s_ + (1. - m) * h_

        atts, att_ctxs = [], []

        for layer in range(depth):
            # perform attention
            # the query is read through the cell that produced the current state
            prev_cell = cell_lower if layer == 0 else cells_higher[layer - 1]
            vle = func.additive_attention(
                prev_cell.get_hidden(s_),
                memory,
                mem_mask,
                mem_shape[-1],
                ln=ln,
                num_heads=num_heads,
                proj_memory=proj_memories,
                scope="deep_attention_{}".format(layer))
            a, c = vle['weights'], vle['output']
            # add a depth axis so the per-layer results can be concatenated
            atts.append(tf.expand_dims(a, 1))
            att_ctxs.append(tf.expand_dims(c, 1))

            # perform next-level recurrence
            c_c = cells_higher[layer].fetch_states(c)
            ss_ = cells_higher[layer](s_, c_c)
            s_ = m * ss_ + (1. - m) * s_

        h = s_
        a = tf.concat(atts, axis=1)
        c = tf.concat(att_ctxs, axis=1)

        return t + 1, h, c, a

    time = tf.constant(0, dtype=tf.int32, name="time")
    step_states = (time, init_state, init_context, init_weight)
    step_vars = cache_inputs + [mask_ta]

    outputs = tf.scan(_step_fn,
                      step_vars,
                      initializer=step_states,
                      parallel_iterations=32,
                      swap_memory=sm)

    # scan results follow the carry layout: (time, state, context, weight)
    output_ta = outputs[1]
    context_ta = outputs[2]
    attention_ta = outputs[3]

    # back to batch-major layout
    outputs = tf.transpose(output_ta, [1, 0, 2])
    # state at the last time step
    output_states = outputs[:, -1]
    # batch x target length x depth x mem-dimension
    contexts = tf.transpose(context_ta, [1, 0, 2, 3])
    # batch x num_heads x depth x target length x source length
    attentions = tf.transpose(attention_ta, [1, 3, 2, 0, 4])

    return (outputs, output_states), \
           (cells_higher[-1].get_hidden(outputs), cells_higher[-1].get_hidden(output_states)), \
        contexts, attentions
Exemplo n.º 24
0
def encoder(source, params):
    """
    Transformer encoder: embeds the source and applies the stacked
    self-attention / feed-forward layers.

    :param source: source token ids; the mask is derived from the non-zero
        entries
    :param params: hyper-parameter object (hidden_size, embed_size,
        num_encoder_layer, num_decoder_layer, dropout rates, ...)
    :return: a dict with the source "encodes", the source "mask" and a
        "decoder_initializer" holding one zero "aan" tensor per decoder layer
    """
    hidden_size = params.hidden_size
    mask = dtype.tf_to_float(tf.cast(source, tf.bool))
    source, mask = util.remove_invalid_seq(source, mask)

    if params.shared_source_target_embedding:
        embed_name = "embedding"
    else:
        embed_name = "src_embedding"
    embed_init = tf.random_normal_initializer(0.0, hidden_size**-0.5)
    src_emb = tf.get_variable(embed_name,
                              [params.src_vocab.size(), params.embed_size],
                              initializer=embed_init)
    src_bias = tf.get_variable("bias", [params.embed_size])

    # scaled embedding lookup + bias + timing signal, then dropout
    embedded = tf.gather(src_emb, source) * (hidden_size**0.5)
    embedded = tf.nn.bias_add(embedded, src_bias)
    embedded = func.add_timing_signal(embedded)
    embedded = util.valid_apply_dropout(embedded, params.dropout)

    with tf.variable_scope("encoder"):
        x = embedded
        for layer in range(params.num_encoder_layer):
            # optional depth-scaled initialization for this layer
            layer_initializer = None
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1)**-0.5,
                    mode="fan_avg",
                    distribution="uniform")

            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                with tf.variable_scope("self_attention"):
                    attn = func.dot_attention(
                        x,
                        None,
                        func.attention_bias(mask, "masking"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout)

                    x = func.residual_fn(x, attn['output'],
                                         dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    ffn_out = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, ffn_out,
                                         dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    batch_dim = util.shape_list(x)[0]

    # plan aan: one zero-initialized accumulator per decoder layer
    decoder_initializer = {}
    for idx in range(params.num_decoder_layer):
        decoder_initializer["layer_{}".format(idx)] = {
            "aan": dtype.tf_to_float(tf.zeros([batch_dim, 1, hidden_size])),
        }

    return {
        "encodes": x,
        "decoder_initializer": decoder_initializer,
        "mask": mask,
    }
Exemplo n.º 25
0
def decoder(target, state, params):
    """
    Transformer decoder built on average attention, followed by the softmax
    and the label-smoothed cross-entropy loss.

    :param target: target token ids; the mask is derived from the non-zero
        entries
    :param state: dict holding "encodes" and "mask"; when it contains
        "decoder" the function runs in cached decoding mode (reads
        ``state['time']`` and updates ``state['decoder']['state']``); when it
        contains "dev_decode" only the last time step is scored
    :param params: hyper-parameter object
    :return: a tuple ``(loss, logits, state, per_sample_loss)``
    """
    hidden_size = params.hidden_size
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    embed_init = tf.random_normal_initializer(0.0, hidden_size**-0.5)

    is_training = 'decoder' not in state

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    if params.shared_source_target_embedding:
        embed_name = "embedding"
    else:
        embed_name = "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size],
                              initializer=embed_init)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.gather(tgt_emb, target) * (hidden_size**0.5)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if not is_training:
        # decoding: an all-padding target is replaced by a zero input
        inputs = tf.cond(
            tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
            lambda: tf.zeros_like(inputs), lambda: inputs)
        mask = tf.ones_like(mask)
        inputs = func.add_timing_signal(inputs,
                                        time=dtype.tf_to_float(state['time']))
    else:
        # training: prepend a zero step and drop the last one
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
        inputs = func.add_timing_signal(inputs)

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            layer_key = "layer_{}".format(layer)

            layer_initializer = None
            if params.deep_transformer_init:
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1)**-0.5,
                    mode="fan_avg",
                    distribution="uniform")

            with tf.variable_scope(layer_key, initializer=layer_initializer):
                with tf.variable_scope("average_attention"):
                    # average the outputs of every configured strategy
                    strategy_outputs = []
                    for strategy in params.strategies:
                        with tf.variable_scope(strategy):
                            strategy_outputs.append(
                                average_attention_strategy(
                                    strategy, x, mask, state, layer, params))
                    x_fwd = tf.add_n(strategy_outputs) / len(strategy_outputs)

                    # FFN activation
                    y = x_fwd
                    if params.use_ffn:
                        y = func.ffn_layer(
                            x_fwd,
                            params.filter_size,
                            hidden_size,
                            dropout=params.relu_dropout,
                        )

                    # Gating layer
                    gates = func.linear(tf.concat([x, y], axis=-1),
                                        hidden_size * 2,
                                        scope="z_project")
                    in_gate, fwd_gate = tf.split(gates, 2, axis=-1)
                    y = tf.sigmoid(in_gate) * x + tf.sigmoid(fwd_gate) * y

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("cross_attention"):
                    layer_cache = None
                    if not is_training:
                        layer_cache = state['decoder']['state'][layer_key]

                    attn = func.dot_attention(
                        x,
                        state['encodes'],
                        func.attention_bias(state['mask'], "masking"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=layer_cache)
                    if not is_training:
                        # mk, mv
                        state['decoder']['state'][layer_key].update(
                            attn['cache'])

                    x = func.residual_fn(x, attn['output'],
                                         dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                with tf.variable_scope("feed_forward"):
                    ffn_out = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, ffn_out,
                                         dropout=params.residual_dropout)
                    x = func.layer_norm(x)

    feature = x[:, -1, :] if 'dev_decode' in state else x

    # pick the softmax matrix; the fully shared embedding takes precedence
    softmax_name = "softmax_embedding"
    if params.shared_target_softmax_embedding:
        softmax_name = "tgt_embedding"
    if params.shared_source_target_embedding:
        softmax_name = "embedding"
    softmax_emb = tf.get_variable(softmax_name,
                                  [params.tgt_vocab.size(), params.embed_size],
                                  initializer=embed_init)
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(target,
                                               util.shape_list(logits)[-1],
                                               factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                          labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    # per-sequence average over unmasked positions, then batch mean
    mask = tf.cast(mask, tf.float32)
    masked_sum = tf.reduce_sum(centropy * mask, -1)
    per_sample_loss = masked_sum / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32), lambda: loss)

    return loss, logits, state, per_sample_loss
Exemplo n.º 26
0
def decoder(target, state, params):
    """
    Deep-attention RNN decoder followed by the softmax and the
    label-smoothed cross-entropy loss.

    :param target: target token ids; the mask is derived from the non-zero
        entries
    :param state: dict holding "encodes", "mask" and "decoder_initializer";
        when it contains "decoder" the function runs in cached decoding mode
        and stores the new hidden state in
        ``state['decoder']['state']['layer']``; when it contains "dev_decode"
        only the last time step is scored
    :param params: hyper-parameter object
    :return: a tuple ``(loss, logits, state, per_sample_loss)``
    """
    hidden_size = params.hidden_size
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))

    is_training = 'decoder' not in state

    # handling target-side word embedding, including shift-padding for training
    if params.shared_source_target_embedding:
        embed_name = "embedding"
    else:
        embed_name = "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size])
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    inputs = tf.nn.bias_add(tf.gather(tgt_emb, target), tgt_bias)

    # shift
    if not is_training:
        # decoding: an all-padding target is replaced by a zero input
        inputs = tf.cond(
            tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
            lambda: tf.zeros_like(inputs), lambda: inputs)
        mask = tf.ones_like(mask)
    else:
        # training: prepend a zero step and drop the last one
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    with tf.variable_scope("decoder"):
        init_state = state["decoder_initializer"]["layer"]
        if not is_training:
            # resume from the state cached by the previous decoding step
            init_state = state["decoder"]["state"]["layer"]

        rnn_pair, hidden_pair, contexts, attentions = deep_att_dec_rnn(
            params.cell,
            inputs,
            state["encodes"],
            hidden_size,
            init_state=init_state,
            mask=mask,
            num_heads=params.num_heads,
            mem_mask=state["mask"],
            ln=params.layer_norm,
            sm=params.swap_memory,
            depth=params.num_decoder_layer)
        hidden_state = rnn_pair[1]
        outputs = hidden_pair[0]

        if not is_training:
            state['decoder']['state']['layer'] = hidden_state

        x = outputs
        # merge the last two context axes into one feature axis
        ctx_dims = util.shape_list(contexts)
        c = tf.reshape(
            contexts,
            [ctx_dims[0], ctx_dims[1], ctx_dims[2] * ctx_dims[3]])

    feature = func.linear(tf.concat([x, c, inputs], -1),
                          params.embed_size,
                          ln=params.layer_norm,
                          scope="ff")
    feature = tf.nn.tanh(feature)
    feature = util.valid_apply_dropout(feature, params.dropout)

    if 'dev_decode' in state:
        feature = feature[:, -1, :]

    # pick the softmax matrix; the fully shared embedding takes precedence
    softmax_name = "softmax_embedding"
    if params.shared_target_softmax_embedding:
        softmax_name = "tgt_embedding"
    if params.shared_source_target_embedding:
        softmax_name = "embedding"
    softmax_emb = tf.get_variable(softmax_name,
                                  [params.tgt_vocab.size(), params.embed_size])
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    soft_label, normalizer = util.label_smooth(target,
                                               util.shape_list(logits)[-1],
                                               factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                          labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    # per-sequence average over unmasked positions, then batch mean
    mask = tf.cast(mask, tf.float32)
    masked_sum = tf.reduce_sum(centropy * mask, -1)
    per_sample_loss = masked_sum / tf.reduce_sum(mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, dtype=tf.float32), lambda: loss)

    return loss, logits, state, per_sample_loss
Exemplo n.º 27
0
def dot_attention(query,
                  memory,
                  mem_mask,
                  hidden_size,
                  ln=False,
                  num_heads=1,
                  cache=None,
                  dropout=None,
                  use_relative_pos=True,
                  max_relative_position=16,
                  out_map=True,
                  scope=None):
    """
    multi-head scaled dot-product attention
    :param query: [batch_size, query_len, dim]
    :param memory: [batch_size, seq_len, mem_dim], or None for self-attention
        (keys and values are then projected from the query itself)
    :param mem_mask: bias added to the attention logits as-is (so it has to
        broadcast against them), or None to skip masking
    :param hidden_size: attention space dimension
    :param ln: whether use layer normalization in the linear projections
    :param num_heads: attention head number
    :param cache: dict for cache-based decoding, or None.
        self-attention reads 'k'/'v', prepends them along axis 1 and returns
        a NEW dict holding the extended 'k'/'v';
        memory attention reuses 'mk'/'mv' when both are present and writes
        them back INTO the given dict
    :param dropout: attention dropout, handed to util.valid_apply_dropout
    :param use_relative_pos: whether use relative position information
    :param max_relative_position: maximum position considered for relative embedding
    :param out_map: whether apply the additional output mapping ("o_map")
    :param scope: variable scope name, "dot_attention" when not given
    :return: a dict with
        'weights': the (dropout-applied) attention weights,
        'output': the attended values after head combination (and "o_map"),
        'cache': the cache as described above (None when no cache is given)
    """
    with tf.variable_scope(scope or "dot_attention"):
        if memory is None:
            # self-attention: a single fused projection yields q, k and v
            qkv = linear(query, hidden_size * 3, ln=ln, scope="qkv_map")
            q, k, v = tf.split(qkv, 3, -1)

            if cache is not None:
                # extend the previously decoded keys/values with this step
                k = tf.concat([cache['k'], k], axis=1)
                v = tf.concat([cache['v'], v], axis=1)
                cache = {'k': k, 'v': v}
        else:
            q = linear(query, hidden_size, ln=ln, scope="q_map")

            has_cached_memory = (cache is not None
                                 and 'mk' in cache and 'mv' in cache)
            if has_cached_memory:
                # memory projections were already computed at an earlier step
                k = cache['mk']
                v = cache['mv']
            else:
                kv = linear(memory, hidden_size * 2, ln=ln, scope="kv_map")
                k, v = tf.split(kv, 2, -1)

            if cache is not None:
                # keep the memory projections in the caller's cache dict
                cache['mk'] = k
                cache['mv'] = v

        q, k, v = [split_heads(t, num_heads) for t in (q, k, v)]

        # scale the queries by 1 / sqrt(per-head dimension)
        head_dim = hidden_size // num_heads
        q = q * head_dim ** (-0.5)

        q_shp, k_shp, v_shp = [util.shape_list(t) for t in (q, k, v)]

        # q * k => attention weights
        if use_relative_pos:
            key_rel = get_relative_positions_embeddings(
                q_shp[2],
                k_shp[2],
                k_shp[3],
                max_relative_position,
                name="relative_positions_keys")
            scores = relative_attention_inner(q, k, key_rel, transpose=True)
        else:
            scores = tf.matmul(q, k, transpose_b=True)

        if mem_mask is not None:
            scores = scores + mem_mask

        weights = util.valid_apply_dropout(tf.nn.softmax(scores), dropout)

        # weights * v => attention vectors
        if use_relative_pos:
            value_rel = get_relative_positions_embeddings(
                q_shp[2],
                k_shp[2],
                v_shp[3],
                max_relative_position,
                name="relative_positions_values")
            attended = relative_attention_inner(
                weights, v, value_rel, transpose=False)
        else:
            attended = tf.matmul(weights, v)

        attended = combine_heads(attended)

        if out_map:
            attended = linear(attended, hidden_size, ln=ln, scope="o_map")

        return {'weights': weights, 'output': attended, 'cache': cache}
Exemplo n.º 28
0
def cond_rnn(cell_name,
             x,
             memory,
             d,
             init_state=None,
             mask=None,
             mem_mask=None,
             ln=False,
             sm=True,
             one2one=False):
    """Self implemented conditional-RNN procedure, supporting mask trick

    Every time step runs two cells: the lower cell consumes the input, a
    context vector is then taken from `memory` (additive attention, or the
    time-aligned memory slice when `one2one` is set), and the higher cell
    consumes that context.

    :param cell_name: gru, lstm or atr
    :param x: input sequence embedding matrix, [batch, seq_len, dim]
    :param memory: the conditional part, [batch, mem_seq_len, mem_dim]
    :param d: hidden dimension for rnn
    :param init_state: the initial hidden states, for cache purpose;
        defaults to the lower cell's initial state
    :param mask: mask matrix, [batch, seq_len]; defaults to all ones
    :param mem_mask: memory mask matrix, [batch, mem_seq_len]; defaults to
        all ones
    :param ln: whether use layer normalization
    :param sm: whether apply swap memory during rnn scan
    :param one2one: whether the memory is one-to-one mapping for x
        (mem_seq_len is then expected to equal seq_len)
    :return: a tuple of
        (outputs, output_states): per-step states in batch-major layout and
            the state at the last time step,
        (hidden outputs, hidden output_states): the same two tensors passed
            through `cell_higher.get_hidden`,
        contexts: per-step context vectors, batch-major,
        attentions: per-step attention weights, batch-major
    """
    in_shape = util.shape_list(x)
    batch_size, time_steps = in_shape[:2]
    mem_shape = util.shape_list(memory)

    cell_lower = get_cell(cell_name,
                          d,
                          ln=ln,
                          scope="{}_lower".format(cell_name))
    cell_higher = get_cell(cell_name,
                           d,
                           ln=ln,
                           scope="{}_higher".format(cell_name))

    if init_state is None:
        init_state = cell_lower.get_init_state(shape=[batch_size])
    if mask is None:
        mask = tf.ones([batch_size, time_steps], tf.float32)
    if mem_mask is None:
        mem_mask = tf.ones([batch_size, mem_shape[1]], tf.float32)

    # prepare projected encodes and inputs; tf.scan iterates over axis 0,
    # so everything that is scanned is moved to time-major layout
    cache_inputs = cell_lower.fetch_states(x)
    cache_inputs = [tf.transpose(v, [1, 0, 2]) for v in cache_inputs]
    if not one2one:
        proj_memories = linear(memory,
                               mem_shape[-1],
                               bias=False,
                               ln=ln,
                               scope="context_att")
    else:
        cache_memories = cell_higher.fetch_states(memory)
        cache_memories = [
            tf.transpose(v, [1, 0, 2]) for v in cache_memories
        ]
    mask_ta = tf.transpose(tf.expand_dims(mask, -1), [1, 0, 2])
    init_context = tf.zeros([batch_size, mem_shape[-1]], tf.float32)
    init_weight = tf.zeros([batch_size, mem_shape[1]], tf.float32)
    # position of the mask inside the per-step element list
    mask_pos = len(cache_inputs)

    def _step_fn(prev, step):
        # the previous context/attention are carried only to fix the scan
        # output structure; they are not read
        t, h_, c_, a_ = prev

        if not one2one:
            m, v = step[mask_pos], step[:mask_pos]
        else:
            # layout: inputs..., mask, projected memories..., raw memory
            c, c_c = step[-1], step[mask_pos + 1:-1]
            m, v = step[mask_pos], step[:mask_pos]

        s = cell_lower(h_, v)
        # masked positions keep the previous state
        s = m * s + (1. - m) * h_

        if not one2one:
            a, c = additive_attention(cell_lower.get_hidden(s),
                                      memory,
                                      mem_mask,
                                      mem_shape[-1],
                                      ln=ln,
                                      proj_memory=proj_memories,
                                      scope="attention")
            c_c = cell_higher.fetch_states(c)
        else:
            # one-hot "attention" on the current time step.
            # tf.equal is required here: the Python `==` operator is not an
            # element-wise comparison on TF 1.x tensors.
            a = tf.tile(tf.expand_dims(tf.range(time_steps), 0),
                        [batch_size, 1])
            a = tf.cast(tf.equal(a, t), tf.float32)
            a = tf.reshape(a, tf.shape(init_weight))

        h = cell_higher(s, c_c)
        h = m * h + (1. - m) * s

        return t + 1, h, c, a

    time = tf.constant(0, dtype=tf.int32, name="time")
    step_states = (time, init_state, init_context, init_weight)
    step_vars = cache_inputs + [mask_ta]
    if one2one:
        # the raw memory is scanned as well, hence it has to be time-major
        # like the other elements so that each step sees a [batch, mem_dim]
        # slice matching `init_context`
        step_vars += cache_memories + [tf.transpose(memory, [1, 0, 2])]

    outputs = tf.scan(_step_fn,
                      step_vars,
                      initializer=step_states,
                      parallel_iterations=32,
                      swap_memory=sm)

    output_ta = outputs[1]
    context_ta = outputs[2]
    attention_ta = outputs[3]

    # back to batch-major layout
    outputs = tf.transpose(output_ta, [1, 0, 2])
    output_states = outputs[:, -1]
    contexts = tf.transpose(context_ta, [1, 0, 2])
    attentions = tf.transpose(attention_ta, [1, 0, 2])

    return (outputs, output_states), \
           (cell_higher.get_hidden(outputs), cell_higher.get_hidden(output_states)), \
           contexts, attentions
Exemplo n.º 29
0
def decoder(target, state, params):
    """
    transformer decoder with L0Drop pruning of the source encodings
    :param target: target token ids; used both as embedding indices and,
        cast to bool, as the loss mask
    :param state: dict carrying "encodes" and "mask" from the encoder.
        The presence of the 'decoder' key switches to decoding mode, where
        'time' and the per-layer caches under state['decoder']['state'] are
        read and updated in place
    :param params: hyper-parameter object (vocabulary, sizes, dropout rates,
        embedding sharing flags, L0 regularization settings)
    :return: loss, logits, state, per_sample_loss
    """
    # positions whose id casts to True (non-zero ids) take part in the loss
    mask = dtype.tf_to_float(tf.cast(target, tf.bool))
    hidden_size = params.hidden_size
    initializer = tf.random_normal_initializer(0.0, hidden_size**-0.5)

    # the 'decoder' entry only exists in the state during decoding
    is_training = ('decoder' not in state)

    if is_training:
        target, mask = util.remove_invalid_seq(target, mask)

    # the embedding variable name decides whether source and target share it
    embed_name = "embedding" if params.shared_source_target_embedding \
        else "tgt_embedding"
    tgt_emb = tf.get_variable(embed_name,
                              [params.tgt_vocab.size(), params.embed_size],
                              initializer=initializer)
    tgt_bias = tf.get_variable("bias", [params.embed_size])

    # scaled embedding lookup plus bias
    inputs = tf.gather(tgt_emb, target) * (hidden_size**0.5)
    inputs = tf.nn.bias_add(inputs, tgt_bias)

    # shift
    if is_training:
        # prepend one zero step on axis 1 and drop the last one, so that
        # position i is fed with the embedding of token i - 1
        inputs = tf.pad(inputs, [[0, 0], [1, 0], [0, 0]])
        inputs = inputs[:, :-1, :]
        inputs = func.add_timing_signal(inputs)
    else:
        # when every target id is the pad id, feed zeros instead of the
        # pad embedding
        inputs = tf.cond(
            tf.reduce_all(tf.equal(target, params.tgt_vocab.pad())),
            lambda: tf.zeros_like(inputs), lambda: inputs)
        mask = tf.ones_like(mask)
        inputs = func.add_timing_signal(inputs,
                                        time=dtype.tf_to_float(state['time']))

    inputs = util.valid_apply_dropout(inputs, params.dropout)

    # Applying L0Drop
    # --------
    source_memory = state["encodes"]
    source_mask = state["mask"]

    # source_pruning: log alpha_i = x_i w^T
    source_pruning = func.linear(source_memory, 1, scope="source_pruning")

    if is_training:  # training
        source_memory, l0_mask = l0norm.var_train(
            (source_memory, source_pruning))
        # L0 penalty: masked average over source positions, then batch mean
        l0_norm_loss = tf.squeeze(l0norm.l0_norm(source_pruning), -1)
        l0_norm_loss = tf.reduce_sum(l0_norm_loss * source_mask,
                                     -1) / tf.reduce_sum(source_mask, -1)
        l0_norm_loss = tf.reduce_mean(l0_norm_loss)
        l0_norm_loss = l0norm.l0_regularization_loss(
            l0_norm_loss,
            reg_scalar=params.l0_norm_reg_scalar,
            start_reg_ramp_up=params.l0_norm_start_reg_ramp_up,
            end_reg_ramp_up=params.l0_norm_end_reg_ramp_up,
            warm_up=params.l0_norm_warm_up,
        )

        # force the model to only attend to unmasked position
        source_mask = dtype.tf_to_float(
            tf.cast(tf.squeeze(l0_mask, -1), tf.bool)) * source_mask
    else:  # evaluation
        source_memory, l0_mask = l0norm.var_eval(
            (source_memory, source_pruning))
        l0_norm_loss = 0.0

        # NOTE(review): extract_encodes is defined elsewhere; presumably it
        # compacts the encodings kept by l0_mask -- confirm
        source_memory, source_mask, count_mask = extract_encodes(
            source_memory, source_mask, l0_mask)
        count_mask = tf.expand_dims(tf.expand_dims(count_mask, 1), 1)
    # --------

    with tf.variable_scope("decoder"):
        x = inputs
        for layer in range(params.num_decoder_layer):
            if params.deep_transformer_init:
                # initializer scale shrinks with depth: gain * (layer+1)^-0.5
                layer_initializer = tf.variance_scaling_initializer(
                    params.initializer_gain * (layer + 1)**-0.5,
                    mode="fan_avg",
                    distribution="uniform")
            else:
                layer_initializer = None
            with tf.variable_scope("layer_{}".format(layer),
                                   initializer=layer_initializer):
                # sub-layer 1: self-attention with a causal bias
                with tf.variable_scope("self_attention"):
                    y = func.dot_attention(
                        x,
                        None,
                        func.attention_bias(tf.shape(mask)[1], "causal"),
                        hidden_size,
                        num_heads=params.num_heads,
                        dropout=params.attention_dropout,
                        cache=None if is_training else
                        state['decoder']['state']['layer_{}'.format(layer)])
                    if not is_training:
                        # k, v
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                # sub-layer 2: attention over the (pruned) source encodings
                with tf.variable_scope("cross_attention"):
                    if is_training:
                        y = func.dot_attention(
                            x,
                            source_memory,
                            func.attention_bias(source_mask, "masking"),
                            hidden_size,
                            num_heads=params.num_heads,
                            dropout=params.attention_dropout,
                        )
                    else:
                        # decoding uses the module-level dot_attention (not
                        # func.dot_attention), which accepts count_mask; its
                        # definition is not part of this function
                        y = dot_attention(x,
                                          source_memory,
                                          func.attention_bias(
                                              source_mask, "masking"),
                                          hidden_size,
                                          count_mask=count_mask,
                                          num_heads=params.num_heads,
                                          dropout=params.attention_dropout,
                                          cache=state['decoder']['state'][
                                              'layer_{}'.format(layer)])

                        # mk, mv
                        state['decoder']['state']['layer_{}'.format(layer)] \
                            .update(y['cache'])

                    y = y['output']
                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)

                # sub-layer 3: position-wise feed-forward network
                with tf.variable_scope("feed_forward"):
                    y = func.ffn_layer(
                        x,
                        params.filter_size,
                        hidden_size,
                        dropout=params.relu_dropout,
                    )

                    x = func.residual_fn(x, y, dropout=params.residual_dropout)
                    x = func.layer_norm(x)
    feature = x
    if 'dev_decode' in state:
        # only the last position is scored in this mode
        feature = x[:, -1, :]

    # pick the output embedding: shared with source+target, shared with the
    # target embedding only, or a separate softmax embedding
    embed_name = "tgt_embedding" if params.shared_target_softmax_embedding \
        else "softmax_embedding"
    embed_name = "embedding" if params.shared_source_target_embedding \
        else embed_name
    softmax_emb = tf.get_variable(embed_name,
                                  [params.tgt_vocab.size(), params.embed_size],
                                  initializer=initializer)
    # flatten to 2-D and project onto the vocabulary: feature * softmax_emb^T
    feature = tf.reshape(feature, [-1, params.embed_size])
    logits = tf.matmul(feature, softmax_emb, False, True)

    logits = tf.cast(logits, tf.float32)

    # label-smoothed cross entropy; the normalizer returned by label_smooth
    # is subtracted from it
    soft_label, normalizer = util.label_smooth(target,
                                               util.shape_list(logits)[-1],
                                               factor=params.label_smooth)
    centropy = tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits,
                                                          labels=soft_label)
    centropy -= normalizer
    centropy = tf.reshape(centropy, tf.shape(target))

    # masked mean over the last axis per sample, then mean over the batch
    mask = tf.cast(mask, tf.float32)
    per_sample_loss = tf.reduce_sum(centropy * mask, -1) / tf.reduce_sum(
        mask, -1)
    loss = tf.reduce_mean(per_sample_loss)

    # add the L0 regularization term (0.0 outside of training)
    loss = loss + l0_norm_loss

    # these mask tricks mainly used to deal with zero shapes, such as [0, 1]
    loss = tf.cond(tf.equal(tf.shape(target)[0], 0),
                   lambda: tf.constant(0, tf.float32), lambda: loss)

    return loss, logits, state, per_sample_loss