Example #1
def multi_head_w_global(
        x, scope, n_state, n_head, train=None, scale=False, resid_dropout=0.9, attn_dropout=0.9,
        use_global=False, use_direction=False, b=None, global_afn='exp',
):
    assert n_state % n_head == 0
    with tf.variable_scope(scope):
        sl = shape_list(x)[-2]
        if not use_direction:
            b = tf.matrix_band_part(tf.ones([sl, sl]), -1, 0)  # Lower triangular part.
            b = tf.reshape(b, [1, 1, sl, sl])

        c = conv1d(x, 'c_attn_openai_trans', n_state * 3, 1, train=train)
        q, k, v = tf.split(c, 3, 2)
        q = split_heads(q, n_head)  # bs,hd,sl,d
        k = split_heads(k, n_head, k=True)  # bs,hd,d,sl
        v = split_heads(v, n_head)  # bs,hd,sl,d

        # 1. t2t: token2token (dot-product) attention logits
        w = tf.matmul(q, k)  # bs,hd,sl,sl
        if scale:
            n_state_hd = shape_list(v)[-1]
            w = w * tf.rsqrt(tf.cast(n_state_hd, tf.float32))

        if use_global:
            e_w = activation_name_to_func(global_afn)(w) * b

            # 2. s2t: source2token (per-dimension) attention logits
            w_g = split_heads(conv1d(x, "c_w_g", n_state, 1, train=train), n_head)  # bs,hd,sl,d
            e_w_g = tf.exp(w_g)  # bs,hd,sl,d

            # 3. mtsa: combine both scores into multi-dim token-level attention
            accum_z_deno = tf.matmul(e_w, e_w_g)  # bs,hd,sl,d
            accum_z_deno = tf.where(  # avoid zeros in the denominator (would give NaN/Inf)
                tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
                accum_z_deno,
                tf.ones_like(accum_z_deno)
            )
            # sqrt: presumably so the product of the two dropped factors
            # keeps roughly the overall attn_dropout rate
            e_w = dropout(e_w, math.sqrt(attn_dropout), train)
            e_w_g = dropout(e_w_g, math.sqrt(attn_dropout), train)
            rep_mul_score = v * e_w_g
            accum_rep_mul_score = tf.matmul(e_w, rep_mul_score)
            a = accum_rep_mul_score / accum_z_deno
        else:
            w = w * b + -1e9 * (1 - b)
            w = tf.nn.softmax(w)
            w = w * b  # re-apply the mask: zero the tiny post-softmax weights on masked positions
            w = dropout(w, attn_dropout, train)  # attention dropout
            a = tf.matmul(w, v)

        a = merge_heads(a)
        a = conv1d(a, 'c_proj_openai_trans', n_state, 1, train=train)
        a = dropout(a, resid_dropout, train)
        return a
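
The use_global branch is easier to follow outside TF. A minimal NumPy sketch of the same algebra (single head, no batch, dropout omitted; all shapes illustrative):

import numpy as np

sl, d = 4, 3
rng = np.random.default_rng(0)
w = rng.standard_normal((sl, sl))        # token2token logits (q @ k)
w_g = rng.standard_normal((sl, d))       # source2token logits
v = rng.standard_normal((sl, d))

b = np.tril(np.ones((sl, sl)))           # causal mask, as in the TF code
e_w = np.exp(w) * b                      # masked token2token weights
e_w_g = np.exp(w_g)                      # per-dimension gates

deno = e_w @ e_w_g                       # [sl, d] normaliser
deno = np.where(deno > 0, deno, 1.0)     # same guard against 0/NaN
a = (e_w @ (v * e_w_g)) / deno           # [sl, d] dimension-wise weighted mean
print(a.shape)                           # (4, 3)
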
Example #2
def norm(x, scope, axis=[-1]):  # read
    with tf.variable_scope(scope):
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        g, b = get_ema_vars(g, b)
        return _norm(x, g, b, axis=axis)
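
_norm and get_ema_vars are defined elsewhere in the repository. As an assumption, a standard layer-norm body consistent with the (x, g, b, axis) call above would be:

def _norm(x, g=None, b=None, e=1e-5, axis=[-1]):
    # assumed implementation: plain layer normalisation with learned scale/shift
    u = tf.reduce_mean(x, axis=axis, keepdims=True)
    s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
    x = (x - u) * tf.rsqrt(s + e)
    if g is not None and b is not None:
        x = x * g + b
    return x
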
Example #3
def mlp(x, scope, n_state, train=None, afn='gelu', resid_dropout=0.9):  # read: position-wise feed-forward (fc -> activation -> projection)
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        act = act_name2fn(afn)
        h = act(conv1d(x, 'c_fc_openai_trans', n_state, 1, train=train))
        h2 = conv1d(h, 'c_proj_openai_trans', nx, 1, train=train)
        h2 = dropout(h2, resid_dropout, train)
        return h2
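
act_name2fn is likewise external; assuming 'gelu' resolves to the usual tanh approximation from the GPT codebase, a sketch:

import math

def gelu(x):
    # tanh approximation of the Gaussian Error Linear Unit (assumption,
    # not necessarily the repository's exact definition)
    return 0.5 * x * (1 + tf.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * tf.pow(x, 3))))
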
Example #4
def get_transformer_clf_features(inp_emb, inp_token, clf_token):
    assert_rank(inp_emb, 3)
    with tf.name_scope('get_transformer_clf_features'):
        bs, sl, embd_dim = shape_list(inp_emb)

        bs_idxs = tf.range(0, bs)  # [bs]
        sent_idxs = tf.argmax(tf.cast(tf.equal(inp_token, clf_token), tf.float32), -1)  # bs
        feature_idxs = tf.stack([bs_idxs, tf.cast(sent_idxs, tf.int32)], -1)  # [bs,2]
        return tf.gather_nd(inp_emb, feature_idxs)  # [bs,inp_dim]
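
As a runnable NumPy toy (hypothetical token ids, clf_token = 9), the gather picks the embedding at the first occurrence of the classification token in each sequence:

import numpy as np

inp_token = np.array([[5, 9, 0, 0],
                      [7, 3, 9, 0]])                        # [bs, sl]
inp_emb = np.arange(24, dtype=np.float32).reshape(2, 4, 3)  # [bs, sl, d]

sent_idxs = (inp_token == 9).argmax(-1)        # first clf position per row
features = inp_emb[np.arange(2), sent_idxs]    # same result as tf.gather_nd
print(sent_idxs)        # [1 2]
print(features.shape)   # (2, 3)
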
Example #5
def conv1d(x, scope, nf, rf, w_init=tf.random_normal_initializer(stddev=0.02), b_init=tf.constant_initializer(0), pad='VALID', train=None):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        w = tf.get_variable("w", [rf, nx, nf], initializer=w_init)
        b = tf.get_variable("b", [nf], initializer=b_init)
        if rf == 1:  # faster 1x1 conv: a reshape + matmul instead of tf.nn.conv1d
            c = tf.reshape(
                tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
                shape_list(x)[:-1] + [nf])
        else:  # general kernel width; this branch was used to train the LM
            c = tf.nn.conv1d(x, w, stride=1, padding=pad) + b
        return c
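
The rf == 1 fast path relies on a width-1 convolution being the same [nx, nf] projection applied independently at every position; a NumPy check of that equivalence:

import numpy as np

bs, sl, nx, nf = 2, 5, 8, 16
rng = np.random.default_rng(0)
x = rng.standard_normal((bs, sl, nx))
w = rng.standard_normal((1, nx, nf))   # rf == 1 kernel
bias = np.zeros(nf)

via_matmul = (x.reshape(-1, nx) @ w.reshape(nx, nf) + bias).reshape(bs, sl, nf)
via_perpos = np.einsum('bsx,xf->bsf', x, w[0]) + bias  # per-position projection
assert np.allclose(via_matmul, via_perpos)
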
Example #6
def multi_head_block_openai(
        x, scope, train=None, scale=False, n_head=12, afn='gelu',
        resid_dropout=0.9, attn_dropout=0.9, reuse=None):
    assert_rank(x, 3)
    with tf.variable_scope(scope, reuse=reuse):
        nx = shape_list(x)[-1]
        a = multi_head(
            x, 'attn', nx, n_head, train=train, scale=scale,
            resid_dropout=resid_dropout, attn_dropout=attn_dropout, )
        n = norm(x + a, 'ln_1_openai_trans')
        m = mlp(n, 'mlp', nx * 4, train=train, afn=afn, resid_dropout=resid_dropout)
        h = norm(n + m, 'ln_2_openai_trans')
        return h
Example #7
def _attn(q, k, v, train=None, scale=False, attn_dropout=0.9):  # read
    w = tf.matmul(q, k)

    if scale:
        n_state = shape_list(v)[-1]
        w = w*tf.rsqrt(tf.cast(n_state, tf.float32))

    w = mask_attn_weights(w)  # note: this makes the self-attention uni-directional (causal)
    w = tf.nn.softmax(w)

    w = dropout(w, attn_dropout, train)  # attention dropout

    a = tf.matmul(w, v)
    return a
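
The same computation without TF or dropout, to make the causal masking visible (a single-head NumPy sketch):

import numpy as np

sl, d = 3, 4
rng = np.random.default_rng(1)
q = rng.standard_normal((sl, d))
k = rng.standard_normal((d, sl))   # pre-transposed, as split_heads(k, k=True) yields
v = rng.standard_normal((sl, d))

w = (q @ k) / np.sqrt(d)           # scaled dot-product logits
b = np.tril(np.ones((sl, sl)))
w = w * b - 1e9 * (1 - b)          # mask_attn_weights: block future positions
w = np.exp(w - w.max(-1, keepdims=True))
w /= w.sum(-1, keepdims=True)      # row-wise softmax
print(np.triu(w, 1).max())         # 0.0: no weight on future tokens
a = w @ v                          # [sl, d]
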
Example #8
def multi_head_block(
        x, scope, train=None, scale=False, n_head=12, afn='gelu',
        resid_dropout=0.9, attn_dropout=0.9, reuse=None,
        use_global=False, use_direction=False, x_mask=None, global_afn='exp', attn_self=True,
):
    assert_rank(x, 3)
    with tf.variable_scope(scope, reuse=reuse):
        # generate the directional attention mask b
        if use_direction:
            b = tf.transpose(mask_ft_generation(x_mask, n_head, True, attn_self=attn_self), [1, 0, 2, 3])
        else:
            b = None

        nx = shape_list(x)[-1]

        a = multi_head_w_global(
            x, 'attn', nx, n_head, train=train, scale=scale,
            resid_dropout=resid_dropout, attn_dropout=attn_dropout,
            use_global=use_global, use_direction=use_direction, b=b, global_afn=global_afn
        )
        n = norm(x+a, 'ln_1_openai_trans')
        m = mlp(n, 'mlp', nx*4, train=train, afn=afn, resid_dropout=resid_dropout)
        h = norm(n+m, 'ln_2_openai_trans')
        return h
Example #9
def merge_states(x):  # read
    x_shape = shape_list(x)
    new_x_shape = x_shape[:-2]+[np.prod(x_shape[-2:])]
    return tf.reshape(x, new_x_shape)
Example #10
def split_states(x, n):  # read
    x_shape = shape_list(x)
    m = x_shape[-1]
    new_x_shape = x_shape[:-1]+[n, m//n]
    return tf.reshape(x, new_x_shape)
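
split_states and merge_states are inverse reshapes. In NumPy terms (shapes illustrative):

import numpy as np

x = np.arange(24).reshape(2, 3, 4)   # [bs, sl, n_state], n_head = 2
split = x.reshape(2, 3, 2, 2)        # split_states(x, 2): [bs, sl, hd, d]
merged = split.reshape(2, 3, 4)      # merge_states(split): back to [bs, sl, n_state]
assert (merged == x).all()
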
Example #11
def mask_attn_weights(w):  # read
    n = shape_list(w)[-1]
    b = tf.matrix_band_part(tf.ones([n, n]), -1, 0)  # Lower triangular part.
    b = tf.reshape(b, [1, 1, n, n])
    w = w*b + -1e9*(1-b)
    return w
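
tf.matrix_band_part(ones, -1, 0) keeps everything on and below the main diagonal; the NumPy equivalent:

import numpy as np

n = 4
b = np.tril(np.ones((n, n)))   # == tf.matrix_band_part(tf.ones([n, n]), -1, 0)
print(b)
# [[1. 0. 0. 0.]
#  [1. 1. 0. 0.]
#  [1. 1. 1. 0.]
#  [1. 1. 1. 1.]]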