import numpy as np
import tensorflow as tf

# Note: n_head, n_layer, n_vocab, n_special, n_ctx, n_embd, afn, act_fns,
# clf_token, and the *_pdrop values are module-level hyperparameters defined
# elsewhere in the repository.

def block(x, scope, train=False, scale=False):
    # Transformer block: multi-head self-attention followed by a position-wise
    # MLP, each wrapped in a residual connection and layer normalization.
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        a = attn(x, 'attn', nx, n_head, train=train, scale=scale)
        n = norm(x + a, 'ln_1')
        m = mlp(n, 'mlp', nx * 4, train=train)
        h = norm(n + m, 'ln_2')
        return h
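# The helpers below are not part of this excerpt; they are minimal sketches of
# what the code above appears to assume. shape_list is assumed to return static
# dimensions where known and dynamic tensors otherwise, and dropout is assumed
# to be a no-op unless training with a positive drop probability.

def shape_list(x):
    # Mix static and dynamic shapes so reshape targets stay as concrete as possible.
    ps = x.get_shape().as_list()
    ts = tf.shape(x)
    return [ts[i] if ps[i] is None else ps[i] for i in range(len(ps))]

def dropout(x, pdrop, train):
    # Apply dropout only at train time and only when pdrop > 0.
    if train and pdrop > 0:
        x = tf.nn.dropout(x, 1 - pdrop)
    return x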
def norm(x, scope, axis=[-1]):
    # Layer normalization over the last dimension with learned gain g and bias b.
    with tf.variable_scope(scope):
        n_state = shape_list(x)[-1]
        g = tf.get_variable("g", [n_state], initializer=tf.constant_initializer(1))
        b = tf.get_variable("b", [n_state], initializer=tf.constant_initializer(0))
        return _norm(x, g, b, axis=axis)
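# _norm is not shown in this excerpt. Below is a minimal sketch consistent with
# how norm() calls it: standard layer normalization over the given axes with an
# optional learned gain and bias. The epsilon value is an assumption.

def _norm(x, g=None, b=None, e=1e-5, axis=[-1]):
    # Normalize to zero mean and unit variance along axis, then rescale/shift.
    u = tf.reduce_mean(x, axis=axis, keepdims=True)
    s = tf.reduce_mean(tf.square(x - u), axis=axis, keepdims=True)
    x = (x - u) * tf.rsqrt(s + e)
    if g is not None and b is not None:
        x = x * g + b
    return x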
def mlp(x, scope, n_state, train=False):
    # Position-wise feed-forward network: expand to n_state units, apply the
    # configured activation (act_fns[afn]), then project back to the input width.
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        act = act_fns[afn]
        h = act(conv1d(x, 'c_fc', n_state, 1, train=train))
        h2 = conv1d(h, 'c_proj', nx, 1, train=train)
        h2 = dropout(h2, resid_pdrop, train)
        return h2
def model(X, M, Y, train=False, reuse=False):
    with tf.variable_scope('model', reuse=reuse):
        # Shared embedding matrix covering the vocabulary, special tokens,
        # and positions.
        we = tf.get_variable(
            "we", [n_vocab + n_special + n_ctx, n_embd],
            initializer=tf.random_normal_initializer(stddev=0.02))
        we = dropout(we, embd_pdrop, train)

        X = tf.reshape(X, [-1, n_ctx, 2])
        M = tf.reshape(M, [-1, n_ctx])

        h = embed(X, we)
        for layer in range(n_layer):
            h = block(h, 'h%d' % layer, train=train, scale=True)

        # Language-modeling head: predict the next token at each position,
        # reusing the embedding matrix as the output projection, and average
        # the per-token losses under the mask M.
        lm_h = tf.reshape(h[:, :-1], [-1, n_embd])
        lm_logits = tf.matmul(lm_h, we, transpose_b=True)
        lm_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=lm_logits, labels=tf.reshape(X[:, 1:, 0], [-1]))
        lm_losses = tf.reshape(
            lm_losses, [shape_list(X)[0], shape_list(X)[1] - 1])
        lm_losses = tf.reduce_sum(lm_losses * M[:, 1:], 1) / tf.reduce_sum(
            M[:, 1:], 1)

        # Classification head: pool the hidden state at the classifier token,
        # then score the two candidate sequences with a linear layer.
        clf_h = tf.reshape(h, [-1, n_embd])
        pool_idx = tf.cast(
            tf.argmax(tf.cast(tf.equal(X[:, :, 0], clf_token), tf.float32), 1),
            tf.int32)
        clf_h = tf.gather(
            clf_h, tf.range(shape_list(X)[0], dtype=tf.int32) * n_ctx + pool_idx)

        clf_h = tf.reshape(clf_h, [-1, 2, n_embd])
        if train and clf_pdrop > 0:
            shape = shape_list(clf_h)
            shape[1] = 1
            clf_h = tf.nn.dropout(clf_h, 1 - clf_pdrop, shape)
        clf_h = tf.reshape(clf_h, [-1, n_embd])
        clf_logits = clf(clf_h, 1, train=train)
        clf_logits = tf.reshape(clf_logits, [-1, 2])
        clf_losses = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=clf_logits, labels=Y)
        return clf_logits, clf_losses, lm_losses
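# embed() is also outside this excerpt. Given that X is reshaped to
# [-1, n_ctx, 2] (token id in channel 0, position id in channel 1) and that a
# single table `we` covers vocabulary, special tokens, and positions, a
# plausible sketch is to look up both channels and sum them:

def embed(X, we):
    # Sum the token and position embeddings gathered from the shared table.
    e = tf.gather(we, X)     # [batch, n_ctx, 2, n_embd]
    h = tf.reduce_sum(e, 2)  # [batch, n_ctx, n_embd]
    return h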
def clf(x, ny, w_init=tf.random_normal_initializer(stddev=0.02),
        b_init=tf.constant_initializer(0), train=False):
    with tf.variable_scope('clf'):
        nx = shape_list(x)[-1]
        w = tf.get_variable("w", [nx, ny], initializer=w_init)
        b = tf.get_variable("b", [ny], initializer=b_init)
        return tf.matmul(x, w) + b
def conv1d(x, scope, nf, rf, w_init=tf.random_normal_initializer(stddev=0.02),
           b_init=tf.constant_initializer(0), pad='VALID', train=False):
    with tf.variable_scope(scope):
        nx = shape_list(x)[-1]
        w = tf.get_variable("w", [rf, nx, nf], initializer=w_init)
        b = tf.get_variable("b", [nf], initializer=b_init)
        if rf == 1:  # faster 1x1 conv
            c = tf.reshape(
                tf.matmul(tf.reshape(x, [-1, nx]), tf.reshape(w, [-1, nf])) + b,
                shape_list(x)[:-1] + [nf])
        else:  # was used to train LM
            c = tf.nn.conv1d(x, w, stride=1, padding=pad) + b
        return c
def _attn(q, k, v, train=False, scale=False):
    # Scaled dot-product attention with a causal mask and attention dropout.
    w = tf.matmul(q, k)
    if scale:
        n_state = shape_list(v)[-1]
        w = w * tf.rsqrt(tf.cast(n_state, tf.float32))
    w = mask_attn_weights(w)
    w = tf.nn.softmax(w)
    w = dropout(w, attn_pdrop, train)
    a = tf.matmul(w, v)
    return a
def split_states(x, n):
    # Reshape the last dimension of x into [n, m // n] (used to split attention heads).
    x_shape = shape_list(x)
    m = x_shape[-1]
    new_x_shape = x_shape[:-1] + [n, m // n]
    return tf.reshape(x, new_x_shape)
def mask_attn_weights(w):
    # Causal mask: keep attention to current and earlier positions, and push
    # logits for future positions to a large negative value before the softmax.
    n = shape_list(w)[-1]
    b = tf.matrix_band_part(tf.ones([n, n]), -1, 0)
    b = tf.reshape(b, [1, 1, n, n])
    w = w * b + -1e9 * (1 - b)
    return w
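# For illustration, with n = 3 the band-part call above yields a
# lower-triangular mask, so each query position attends only to itself and
# earlier positions:
#
#   [[1., 0., 0.],
#    [1., 1., 0.],
#    [1., 1., 1.]]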
def merge_states(x):
    # Inverse of split_states: collapse the last two dimensions back into one.
    x_shape = shape_list(x)
    new_x_shape = x_shape[:-2] + [np.prod(x_shape[-2:])]
    return tf.reshape(x, new_x_shape)
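# attn(), called from block(), is the remaining missing piece of this excerpt.
# The sketch below is an assumption about its shape: project to queries, keys,
# and values with a 1x1 conv, split into n_head heads (keys transposed so that
# tf.matmul(q, k) in _attn contracts over the feature dimension), run _attn,
# merge the heads, and project back, with residual dropout.

def split_heads(x, n, k=False):
    # [batch, seq, n_state] -> [batch, heads, seq, head_dim];
    # keys go to [batch, heads, head_dim, seq] so q @ k gives attention logits.
    if k:
        return tf.transpose(split_states(x, n), [0, 2, 3, 1])
    else:
        return tf.transpose(split_states(x, n), [0, 2, 1, 3])

def merge_heads(x):
    # [batch, heads, seq, head_dim] -> [batch, seq, n_state]
    return merge_states(tf.transpose(x, [0, 2, 1, 3]))

def attn(x, scope, n_state, n_head, train=False, scale=False):
    assert n_state % n_head == 0
    with tf.variable_scope(scope):
        c = conv1d(x, 'c_attn', n_state * 3, 1, train=train)
        q, k, v = tf.split(c, 3, 2)
        q = split_heads(q, n_head)
        k = split_heads(k, n_head, k=True)
        v = split_heads(v, n_head)
        a = _attn(q, k, v, train=train, scale=scale)
        a = merge_heads(a)
        a = conv1d(a, 'c_proj', n_state, 1, train=train)
        a = dropout(a, resid_pdrop, train)
        return a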