def merge_heads(x, scope):
    """ (batch, head, pixel, head_state) -> (batch, pixel, state) """
    with tf.name_scope(scope):
        return merge_states(transpose_0213(x))
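# merge_heads/split_heads rely on merge_states, split_states, and
# transpose_0213, which are not shown in this section. A minimal sketch of
# semantically equivalent helpers (assumption: the library may implement
# transpose_0213 as a fused kernel rather than a plain tf.transpose):

def transpose_0213(x):
    # Swap the two middle axes: (a, b, c, d) -> (a, c, b, d).
    return tf.transpose(x, [0, 2, 1, 3])

def split_states(x, n):
    # (batch, pixel, state) -> (batch, pixel, n, state // n)
    *start, m = x.shape.as_list()
    return tf.reshape(x, start + [n, m // n])

def merge_states(x):
    # (batch, pixel, n, head_state) -> (batch, pixel, n * head_state)
    *start, a, b = x.shape.as_list()
    return tf.reshape(x, start + [a * b])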
def testBlocksparseTransformerDense(self):
    with self.test_session(config=config) as sess, tf.device("/gpu:0"):
        for bsize in (16, 32, 64):

            # Dense layout: every block attends to every block, so the
            # blocksparse path should match a plain matmul reference.
            layout = np.ones([heads, ctx, ctx], dtype=np.bool)
            bst = trans.BlocksparseTransformer(layout, block_size=bsize)

            shape = (batch, ctx * bsize, heads * state)

            if ones:
                cpuQ = np.ones(shape, dtype=np.float32)
                cpuK = np.ones(shape, dtype=np.float32)
                cpuV = np.ones(shape, dtype=np.float32)
                cpuE = np.ones(shape, dtype=np.float32)
            else:
                # Round-trip through fp16 so the reference inputs carry the
                # same precision the fp16 kernels will see.
                cpuQ = np.random.uniform(-1.0, 1.0, shape).astype(np.float16).astype(np.float32)
                cpuK = np.random.uniform(-1.0, 1.0, shape).astype(np.float16).astype(np.float32)
                cpuV = np.random.uniform(-1.0, 1.0, shape).astype(np.float16).astype(np.float32)
                cpuE = np.random.uniform(-1.0, 1.0, shape).astype(np.float16).astype(np.float32)

            q = tf.placeholder(tf.float32, shape)
            k = tf.placeholder(tf.float32, shape)
            v = tf.placeholder(tf.float32, shape)
            e = tf.placeholder(tf.float32, shape)

            feed_dict = {q: cpuQ, k: cpuK, v: cpuV, e: cpuE}

            qf = ew.float_cast(q, dtype=tf.float16)
            kf = ew.float_cast(k, dtype=tf.float16)
            vf = ew.float_cast(v, dtype=tf.float16)

            # Blocksparse path: QK^T -> scaled softmax -> weighted sum of V.
            w = bst.query_key_op(qf, kf)
            w = bst.softmax(w, scale=scale)
            y = bst.weight_value_op(w, vf)

            # Reference path: split heads and run ordinary dense attention
            # on (batch, head, pixel, head_state) tensors.
            qf = trans.transpose_0213(tf.reshape(qf, [batch, ctx * bsize, heads, state]))
            kf = trans.transpose_0213(tf.reshape(kf, [batch, ctx * bsize, heads, state]))
            vf = trans.transpose_0213(tf.reshape(vf, [batch, ctx * bsize, heads, state]))

            W = tf.matmul(qf, kf, transpose_b=True)
            W = trans.softmax(W, scale=scale)
            Y = tf.matmul(W, vf)
            Y = tf.reshape(trans.transpose_0213(Y), [batch, ctx * bsize, heads * state])

            y = ew.float_cast(y, dtype=tf.float32)
            Y = ew.float_cast(Y, dtype=tf.float32)

            # Forward outputs plus gradients w.r.t. q, k, v (with e as the
            # upstream gradient) for both paths.
            y, (dq, dk, dv) = sess.run([y, tf.gradients(y, [q, k, v], e)], feed_dict)
            Y, (DQ, DK, DV) = sess.run([Y, tf.gradients(Y, [q, k, v], e)], feed_dict)

            print("testBlocksparseTransformerDense", bsize)
            if not bench:
                for op, dev, cpu in [
                        [" Y", y, Y],
                        ["DV", dv, DV],
                        ["DK", dk, DK],
                        ["DQ", dq, DQ],
                ]:
                    self.compare_results(op, dev, cpu)
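# The all-ones layout above exists only to make the blocksparse kernels
# comparable to the dense reference; sparse patterns differ solely in the
# block mask handed to the constructor. A hypothetical helper (not used by
# the dense test) sketching a lower-triangular, autoregressive layout:

def _autoregressive_layout(heads, ctx):
    # Only blocks where layout[h, q_block, k_block] is True are computed,
    # so a tril mask restricts each query block to earlier key blocks.
    layout = np.tril(np.ones([ctx, ctx], dtype=np.bool))   # (ctx, ctx)
    return np.repeat(layout[None, :, :], heads, axis=0)    # (heads, ctx, ctx)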
def split_heads(x, n, scope):
    """ (batch, pixel, state) -> (batch, head, pixel, head_state) """
    with tf.name_scope(scope):
        return transpose_0213(split_states(x, n))
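# Usage sketch with hypothetical sizes (the real tests drive these helpers
# with the module-level batch/heads/state constants): split_heads and
# merge_heads invert each other.

def _round_trip_example():
    batch, n_pixels, heads, head_state = 2, 8, 4, 16   # hypothetical sizes
    x = tf.zeros([batch, n_pixels, heads * head_state])
    h = split_heads(x, heads, "split")   # (batch, heads, n_pixels, head_state)
    y = merge_heads(h, "merge")          # (batch, n_pixels, heads * head_state)
    assert y.shape.as_list() == x.shape.as_list()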