import math

import numpy as np
import tensorflow as tf

# Helper functions used below (shape_list, split_heads, merge_heads, dropout,
# act_name2fn, activation_name_to_func, get_ema_vars, _norm, assert_rank,
# mask_ft_generation) are defined elsewhere in this repo.


def multi_head_w_global(  # Added and Modified by xxx xxx
        x, scope, n_state, n_head, train=None, scale=False,
        resid_dropout=0.9, attn_dropout=0.9, use_global=False,
        use_direction=False, b=None, global_afn='exp',
):
    assert n_state % n_head == 0
    with tf.variable_scope(scope):
        sl = shape_list(x)[-2]
        if not use_direction:
            b = tf.matrix_band_part(tf.ones([sl, sl]), -1, 0)  # lower triangular part
            b = tf.reshape(b, [1, 1, sl, sl])
        c = conv1d(x, 'c_attn_openai_trans', n_state * 3, 1, train=train)
        q, k, v = tf.split(c, 3, 2)
        q = split_heads(q, n_head)          # bs,hd,sl,d
        k = split_heads(k, n_head, k=True)  # bs,hd,d,sl
        v = split_heads(v, n_head)          # bs,hd,sl,d
        # 1. token2token (t2t) alignment scores
        w = tf.matmul(q, k)  # bs,hd,sl,sl
        if scale:
            n_state_hd = shape_list(v)[-1]
            w = w * tf.rsqrt(tf.cast(n_state_hd, tf.float32))
        if use_global:
            e_w = activation_name_to_func(global_afn)(w) * b
            # 2. source2token (s2t) alignment scores
            w_g = split_heads(conv1d(x, "c_w_g", n_state, 1, train=train), n_head)  # bs,hd,sl,d
            e_w_g = tf.exp(w_g)  # bs,hd,sl,d
            # 3. mtsa: fuse the two scores, normalising per dimension
            accum_z_deno = tf.matmul(e_w, e_w_g)  # bs,hd,sl,d
            accum_z_deno = tf.where(  # in case of NaN and Inf
                tf.greater(accum_z_deno, tf.zeros_like(accum_z_deno)),
                accum_z_deno,
                tf.ones_like(accum_z_deno)
            )
            e_w = dropout(e_w, math.sqrt(attn_dropout), train)
            e_w_g = dropout(e_w_g, math.sqrt(attn_dropout), train)
            rep_mul_score = v * e_w_g
            accum_rep_mul_score = tf.matmul(e_w, rep_mul_score)
            a = accum_rep_mul_score / accum_z_deno
        else:
            # standard masked softmax attention
            w = w * b + -1e9 * (1 - b)
            w = tf.nn.softmax(w)
            w = w * b  # fixed the bug
            w = dropout(w, attn_dropout, train)  # attention dropout
            a = tf.matmul(w, v)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj_openai_trans', n_state, 1, train=train)
        a = dropout(a, resid_dropout, train)
        return a
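# A minimal NumPy sketch (not repo code) of the t2t/s2t fusion above, for a
# single head and batch element, to make the per-dimension scoring concrete.
# All names here are invented for illustration; global_afn='exp' and causal
# masking are assumed.
def _mtsa_numpy_demo(sl=5, d=4, seed=0):
    rng = np.random.RandomState(seed)
    q, k, v = rng.randn(3, sl, d)
    w_g = rng.randn(sl, d)                                   # s2t logits, one per dim
    causal = np.tril(np.ones((sl, sl)))                      # same role as b above
    e_w = np.exp(np.matmul(q, k.T) / np.sqrt(d)) * causal    # t2t scores
    e_w_g = np.exp(w_g)                                      # s2t scores
    deno = np.matmul(e_w, e_w_g)                             # [sl, d] per-dim normaliser
    nume = np.matmul(e_w, v * e_w_g)                         # [sl, d] per-dim weighted values
    return nume / np.where(deno > 0, deno, 1.0)              # [sl, d] attention output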
def norm(x, scope, axis=[-1]):  # read
    # layer normalisation with learnable gain g and bias b
    with tf.variable_scope(scope):
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        g, b = get_ema_vars(g, b)
        return _norm(x, g, b, axis=axis)
def mlp(x, scope, n_state, train=None, afn='gelu', resid_dropout=0.9):  # read: 2-layer position-wise MLP
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        act = act_name2fn(afn)
        h = act(conv1d(x, 'c_fc_openai_trans', n_state, 1, train=train))
        h2 = conv1d(h, 'c_proj_openai_trans', nx, 1, train=train)
        h2 = dropout(h2, resid_dropout, train)
        return h2
def get_transformer_clf_features(inp_emb, inp_token, clf_token):
    assert_rank(inp_emb, 3)
    with tf.name_scope('get_transformer_clf_features'):
        bs, sl, embd_dim = shape_list(inp_emb)
        bs_idxs = tf.range(0, bs)  # [bs]
        # index of the classification token within each sequence
        sent_idxs = tf.argmax(tf.cast(tf.equal(inp_token, clf_token), tf.float32), -1)  # [bs]
        feature_idxs = tf.stack([bs_idxs, tf.cast(sent_idxs, tf.int32)], -1)  # [bs, 2]
        return tf.gather_nd(inp_emb, feature_idxs)  # [bs, embd_dim]
def conv1d(x, scope, nf, rf, w_init=tf.random_normal_initializer(stddev=0.02),
           b_init=tf.constant_initializer(0), pad='VALID', train=None):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        w = tf.get_variable("w", [rf, nx, nf], initializer=w_init)
        b = tf.get_variable("b", [nf], initializer=b_init)
        if rf == 1:  # faster 1x1 conv: implemented as a per-position matmul
            c = tf.reshape(
                tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
                shape_list(x)[:-1] + [nf])
        else:  # was used to train the LM
            c = tf.nn.conv1d(x, w, stride=1, padding=pad) + b
        return c
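# Hedged check (not repo code): for rf == 1 the reshape/matmul branch above is
# the same per-position linear map as a kernel-size-1 convolution. A small
# NumPy sketch with invented shapes:
def _conv1d_1x1_equivalence_demo(bs=2, sl=3, nx=5, nf=7, seed=0):
    rng = np.random.RandomState(seed)
    x = rng.randn(bs, sl, nx)
    w = rng.randn(nx, nf)
    by_matmul = np.matmul(x.reshape(-1, nx), w).reshape(bs, sl, nf)
    by_einsum = np.einsum('bsn,nf->bsf', x, w)  # apply w independently at every position
    return np.allclose(by_matmul, by_einsum)    # True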
def _attn(q, k, v, train=None, scale=False, attn_dropout=0.9):  # read
    w = tf.matmul(q, k)
    if scale:
        n_state = shape_list(v)[-1]
        w = w * tf.rsqrt(tf.cast(n_state, tf.float32))
    w = mask_attn_weights(w)  # highlight: this mask makes the self-attention uni-directional
    w = tf.nn.softmax(w)
    # w = tf.Print(w, [tf.shape(w)])
    w = dropout(w, attn_dropout, train)  # attention dropout
    a = tf.matmul(w, v)
    return a
def multi_head_block(
        x, scope, train=None, scale=False, n_head=12, afn='gelu',
        resid_dropout=0.9, attn_dropout=0.9, reuse=None,
        use_global=False, use_direction=False, x_mask=None,
        global_afn='exp', attn_self=True,
):
    assert_rank(x, 3)
    with tf.variable_scope(scope, reuse=reuse):
        # generate the directional mask b (None -> causal mask built inside the attention)
        if use_direction:
            b = tf.transpose(
                mask_ft_generation(x_mask, n_head, True, attn_self=attn_self),
                [1, 0, 2, 3])
        else:
            b = None
        nx = shape_list(x)[-1]
        a = multi_head_w_global(
            x, 'attn', nx, n_head, train=train, scale=scale,
            resid_dropout=resid_dropout, attn_dropout=attn_dropout,
            use_global=use_global, use_direction=use_direction, b=b,
            global_afn=global_afn
        )
        n = norm(x + a, 'ln_1_openai_trans')
        m = mlp(n, 'mlp', nx * 4, train=train, afn=afn, resid_dropout=resid_dropout)
        h = norm(n + m, 'ln_2_openai_trans')
        return h
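# Hedged usage sketch (not repo code): one way a stack of the blocks above
# could be wired into a classifier. `n_layer`, `inp_emb`, `inp_token` and
# `clf_token` are assumptions for this example; the actual training script
# may wire things differently.
def _transformer_clf_sketch(inp_emb, inp_token, clf_token, n_layer=12, train=None):
    h = inp_emb  # [bs, sl, d]
    for layer_idx in range(n_layer):
        h = multi_head_block(h, 'h%d' % layer_idx, train=train, scale=True)
    # pool the hidden state at the classification token (see get_transformer_clf_features)
    return get_transformer_clf_features(h, inp_token, clf_token)  # [bs, d]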
def merge_states(x):  # read
    x_shape = shape_list(x)
    new_x_shape = x_shape[:-2] + [np.prod(x_shape[-2:])]
    return tf.reshape(x, new_x_shape)
def split_states(x, n):  # read
    x_shape = shape_list(x)
    m = x_shape[-1]
    new_x_shape = x_shape[:-1] + [n, m // n]
    return tf.reshape(x, new_x_shape)
def mask_attn_weights(w):  # read
    n = shape_list(w)[-1]
    b = tf.matrix_band_part(tf.ones([n, n]), -1, 0)  # Lower triangular part.
    b = tf.reshape(b, [1, 1, n, n])
    w = w * b + -1e9 * (1 - b)
    return w
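# Quick NumPy illustration (not repo code) of the mask built above: row i of
# the softmax output is uniform over positions 0..i and ~0 elsewhere, because
# the -1e9 additive term suppresses future positions.
def _causal_mask_demo(n=4):
    b = np.tril(np.ones((n, n)))              # same as tf.matrix_band_part(ones, -1, 0)
    logits = np.zeros((n, n))                 # uniform logits for the demo
    masked = logits * b + -1e9 * (1 - b)
    probs = np.exp(masked) / np.exp(masked).sum(-1, keepdims=True)
    return np.round(probs, 2)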