import numpy as np
import tensorflow as tf

# Cache of block-sparse transformer objects, keyed by the attention
# configuration, so each sparsity layout is only built once.
bst_dict = {}


def blocksparse_attention_impl(q, k, v, heads, attn_mode, local_attn_ctx=None,
                               blocksize=32, num_verts=None, vertsize=None):
    global bst_dict
    n_ctx = shape_list(q)[1]
    # k and v may carry extra (memory) context; it must be a whole multiple
    # of the query context so the block layout lines up.
    assert shape_list(v)[1] % n_ctx == 0
    if attn_mode == 'strided':
        # Strided attention is implemented on the transposed matrix to
        # provide greater block sparsity.
        q = strided_transpose(q, n_ctx, local_attn_ctx, blocksize)
        k = strided_transpose(k, n_ctx, local_attn_ctx, blocksize)
        v = strided_transpose(v, n_ctx, local_attn_ctx, blocksize)
    n_state = shape_list(q)[-1] // heads
    key = f'{local_attn_ctx}_{n_ctx}_{attn_mode}'
    if key not in bst_dict:
        # get_blocksparse_obj (defined elsewhere in the repo) builds the
        # block-sparse transformer object for this layout; the final argument
        # is the number of extra key/value context multiples beyond n_ctx.
        bst_dict[key] = get_blocksparse_obj(n_ctx, heads, attn_mode, blocksize,
                                            local_attn_ctx, num_verts, vertsize,
                                            shape_list(v)[1] // n_ctx - 1)
    bst = bst_dict[key]
    scale_amount = tf.cast(1.0 / np.sqrt(n_state), tf.float32)
    w = bst.query_key_op(q, k)
    w = bst.masked_softmax(w, scale=scale_amount)
    a = bst.weight_value_op(w, v)
    if attn_mode == 'strided':
        # Undo the strided transpose so outputs are in the original order.
        n, t, embd = shape_list(a)
        bT_ctx = n_ctx // local_attn_ctx
        a = tf.reshape(a, [n, local_attn_ctx, bT_ctx, embd])
        a = tf.transpose(a, [0, 2, 1, 3])
        a = tf.reshape(a, [n, t, embd])
    return a

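# `shape_list` is assumed to be defined elsewhere in this repo; the sketch
# below is the common TF 1.x idiom that returns static dimensions where known
# and dynamic tensors otherwise, which is what the reshapes above rely on.
def shape_list(x):
    """Return the shape of a tensor, preferring static dims over dynamic ones."""
    static = x.shape.as_list()
    dynamic = tf.shape(x)
    return [dynamic[i] if dim is None else dim for i, dim in enumerate(static)]
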
def merge_states(x):
    """
    reshape (batch, pixel, head, head_state) -> (batch, pixel, state)
    """
    x_shape = shape_list(x)
    new_x_shape = x_shape[:-2] + [np.prod(x_shape[-2:])]
    return tf.reshape(x, new_x_shape)

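# `merge_heads` is used by attention_impl below but not defined in this file;
# a minimal sketch in terms of merge_states, assuming the usual
# (batch, head, pixel, head_state) layout produced by split_heads:
def merge_heads(x):
    return merge_states(tf.transpose(x, [0, 2, 1, 3]))
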
def split_states(x, n):
    """
    reshape (batch, pixel, state) -> (batch, pixel, head, head_state)
    """
    x_shape = shape_list(x)
    m = x_shape[-1]
    new_x_shape = x_shape[:-1] + [n, m // n]
    return tf.reshape(x, new_x_shape)

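# `split_heads` is likewise assumed to live elsewhere in the repo; a minimal
# sketch that pairs with merge_heads above, moving the head axis in front of
# the pixel axis so matmuls run per head:
def split_heads(x, n):
    return tf.transpose(split_states(x, n), [0, 2, 1, 3])
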
def strided_transpose(x, n_ctx, local_attn_ctx, blocksize):
    """
    Reorder the sequence so that positions sharing the same offset modulo
    local_attn_ctx (the stride) become contiguous.
    """
    bT_ctx = n_ctx // local_attn_ctx
    assert bT_ctx % blocksize == 0, f'{bT_ctx}, {blocksize}'
    n, t, embd = shape_list(x)
    x = tf.reshape(x, [n, bT_ctx, local_attn_ctx, embd])
    x = tf.transpose(x, [0, 2, 1, 3])
    x = tf.reshape(x, [n, t, embd])
    return x

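# Illustration (assumed usage, not part of the original code): with n_ctx=8
# and local_attn_ctx=4 (stride 4), positions 0..7 are reordered to
# 0, 4, 1, 5, 2, 6, 3, 7, so tokens that attend to each other under the
# strided pattern become adjacent and the mask becomes block-banded.
# blocksize=2 here only to satisfy the divisibility assert above.
#
#   x = tf.reshape(tf.range(8, dtype=tf.float32), [1, 8, 1])
#   strided_transpose(x, n_ctx=8, local_attn_ctx=4, blocksize=2)
#   # -> values [0., 4., 1., 5., 2., 6., 3., 7.]
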
def attention_impl(q, k, v, heads, attn_mode, local_attn_ctx=None):
    q = split_heads(q, heads)
    k = split_heads(k, heads)
    v = split_heads(v, heads)
    n_timesteps = shape_list(k)[2]
    mask = tf.to_float(get_attn_mask(n_timesteps, attn_mode, local_attn_ctx))
    w = tf.matmul(q, k, transpose_b=True)
    scale_amount = 1.0 / np.sqrt(shape_list(q)[-1])
    orig_dtype = q.dtype
    if orig_dtype == tf.float16:
        # Compute the masked softmax in float32 for numerical stability.
        w = tf.cast(w, tf.float32)
    w = w * scale_amount
    w = w * mask + -1e9 * (1 - mask)
    w = tf.nn.softmax(w)
    w = tf.cast(w, orig_dtype)
    a = tf.matmul(w, v)
    a = merge_heads(a)
    return a

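# `get_attn_mask` is not defined in this file; the sketch below is one way to
# build the dense masks attention_impl expects, assuming 'all' means ordinary
# causal attention, 'local' a causal band of width local_attn_ctx, and
# 'strided' causal attention to every local_attn_ctx-th earlier position.
def get_attn_mask(n, attn_mode, local_attn_ctx=None):
    if attn_mode == 'all':
        b = tf.matrix_band_part(tf.ones([n, n]), -1, 0)
    elif attn_mode == 'local':
        ctx = tf.minimum(n - 1, local_attn_ctx - 1)
        b = tf.matrix_band_part(tf.ones([n, n]), ctx, 0)
    elif attn_mode == 'strided':
        stride = local_attn_ctx
        x = tf.reshape(tf.range(n, dtype=tf.int32), [n, 1])
        y = tf.transpose(x)
        z = tf.zeros([n, n], dtype=tf.int32)
        q_idx = z + x
        k_idx = z + y
        causal = q_idx >= k_idx
        on_stride = tf.equal(tf.floormod(q_idx - k_idx, stride), 0)
        b = tf.cast(tf.logical_and(causal, on_stride), tf.float32)
    else:
        raise ValueError(f'Unknown attn_mode: {attn_mode}')
    # Broadcast over the (batch, head) axes used in attention_impl.
    b = tf.reshape(b, [1, 1, n, n])
    return b
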