def unpermute_impl(vecs):
    """Reorder `vecs` along its second-to-last axis using `undo_sort`.

    `undo_sort` is not a parameter: it is looked up in the enclosing scope when
    this function runs. Judging by its name it holds the inverse of an earlier
    sort permutation -- TODO confirm against the code that defines it.

    Args:
      vecs: rank-3 array (the rank is asserted).

    Returns:
      An array where `result[b, i, :] == vecs[b, undo_sort[b, i], :]`.
    """
    rank = len(vecs.shape)
    assert rank == 3
    # The trailing unit axis lets each index select a whole feature vector.
    gather_index = undo_sort[:, :, None]
    return np.take_along_axis(vecs, gather_index, axis=-2)
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
    """Run bucketed shared-query/key attention and, optionally, its backward pass.

    Positions are sorted by (bucket, time), attention is evaluated inside
    fixed-size chunks of the sorted sequence (each chunk also sees one
    neighbouring chunk), and the result is permuted back to the input order.

    Args:
      inputs: a `(qk, ignored_k, v)` triple. `qk` is used as both queries and
        keys; the middle entry is only used to shape its zero cotangent. `qk`
        and `v` are n_batch*n_heads, seqlen, d_head.
      ct: cotangent of the output, or `None` to run the forward pass only.
      rng: forwarded to `self.hash_vectors`.
      **kwargs: accepted and discarded.

    Returns:
      `(out, None)` when `ct` is `None`, otherwise
      `(out, (qk_ct, k_ct, v_ct))` with `k_ct` all zeros.
    """
    del kwargs
    # One vector serves as both query and key. The surrounding code has not
    # been adapted yet, so a separate "key" input still arrives; it is ignored.
    qk, ignored_k, v = inputs
    seqlen = qk.shape[-2]

    # bins is n_batch*n_heads, seqlen: the hash bucket of every position.
    bins = self.hash_vectors(qk, rng=rng)

    # joint_t is n_batch*n_heads, seqlen: the time index of every position.
    positions = jax.lax.tie_in(qk, np.arange(seqlen))
    joint_t = np.broadcast_to(
        np.reshape(positions, (1, seqlen)), qk.shape[:-1])

    assert int((self.n_bins + 1) * seqlen) < 2**31, (
        'Potential 32-bit integer overflow; please double-check the code.')
    # A single integer key that orders by bucket first and by time second.
    sort_keys = seqlen * bins + joint_t

    def chunk_scalars(x):  # pylint: disable=invalid-name
        return np.reshape(x, (x.shape[0], self.n_bins, -1))

    def chunk_vectors(x):  # pylint: disable=invalid-name
        return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

    def unchunk_vectors(x):  # pylint: disable=invalid-name
        return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

    def gather_seq(vecs, order):  # pylint: disable=invalid-name
        # Reorders the sequence axis of a rank-3 array with a rank-2 index.
        return np.take_along_axis(vecs, order[:, :, None], axis=-2)

    def with_previous_chunk(x):  # pylint: disable=invalid-name
        # Appends, to every chunk, the contents of the chunk before it (the
        # first chunk receives the last one).
        # TODO(kitaev): benchmark whether XLA pad operation is noticeably
        # faster.
        previous = np.concatenate([x[:, -1:], x[:, :-1]], axis=1)
        return np.concatenate([x, previous], axis=2)

    # Sort everything by bin number, with a secondary sort by time
    # (variables starting with "s" are sorted).
    _, sjoint_t = jax.lax.sort_key_val(sort_keys, joint_t, dimension=-1)
    sqk = gather_seq(qk, sjoint_t)
    sv = gather_seq(v, sjoint_t)

    @jax.jit
    def binned_attn(sorted_qk, sorted_v):  # pylint: disable=invalid-name
        """Performs attention on sorted queries/keys/values."""
        # Split off a "bin" axis so that attention only occurs within chunks.
        query_t = chunk_scalars(sjoint_t)
        chunked_qk = chunk_vectors(sorted_qk)
        values = chunk_vectors(sorted_v)

        # Hashing operates on unit-length vectors. Unnormalized queries are
        # fine because they effectively provide a learnable temperature for
        # the attention softmax, but keys must be normalized so that attention
        # similarity corresponds to hash locality.
        queries = chunked_qk
        keys = self.make_unit_length(chunked_qk)

        # Let each chunk attend within itself and also one chunk back. Chunk
        # boundaries might fall in the middle of a run of items from the same
        # bin, so this raises the chance of attending to relevant items.
        keys = with_previous_chunk(keys)
        values = with_previous_chunk(values)
        key_t = with_previous_chunk(query_t)

        # Dot-product attention.
        scale = np.sqrt(queries.shape[-1])
        dots = np.matmul(queries, np.swapaxes(keys, -1, -2)) / scale

        # Causal masking: penalize keys that are later in time than the query.
        is_future = jax.lax.lt(
            query_t[:, :, :, None], key_t[:, :, None, :])
        dots = dots - 1e9 * jax.lax.convert_element_type(
            is_future, np.float32)

        # Mask out attention to self except when no other targets are
        # available.
        self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
        self_mask = jax.lax.tie_in(dots, self_mask)
        dots = dots - 32 * self_mask

        # Softmax.
        probs = np.exp(
            dots - backend.logsumexp(dots, axis=-1, keepdims=True))
        return unchunk_vectors(np.matmul(probs, values))

    @jax.jit
    def binned_attn_vjp(sorted_qk, sorted_v, sorted_out_ct):  # pylint: disable=invalid-name
        sorted_out, pullback = jax.vjp(binned_attn, sorted_qk, sorted_v)
        return sorted_out, pullback(sorted_out_ct)

    if ct is None:
        so = binned_attn(sqk, sv)
        sorted_cts = None
    else:
        # Jax can construct a backward pass automatically, but it's about 2x
        # slower than writing our own. The backward pass of gather is in
        # general a scatter, but these gathers are permutations, so the
        # backward pass can be expressed with gathers as well.
        so_ct = gather_seq(ct, sjoint_t)
        so, sorted_cts = binned_attn_vjp(sqk, sv, so_ct)

    # Sorting the time indices by the sorted order gives the inverse
    # permutation, which restores the original sequence order.
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = gather_seq(so, undo_sort)
    if sorted_cts is None:
        return out, None

    sqk_ct, sv_ct = sorted_cts
    qk_ct = gather_seq(sqk_ct, undo_sort)
    v_ct = gather_seq(sv_ct, undo_sort)
    return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
def permute_impl(vecs):
    """Reorder `vecs` along its second-to-last axis using `sjoint_t`.

    `sjoint_t` is not a parameter: it is looked up in the enclosing scope when
    this function runs. It presumably holds time indices in sorted order --
    TODO confirm against the code that defines it.

    Args:
      vecs: rank-3 array (the rank is asserted).

    Returns:
      An array where `result[b, i, :] == vecs[b, sjoint_t[b, i], :]`.
    """
    rank = len(vecs.shape)
    assert rank == 3
    # The trailing unit axis lets each index select a whole feature vector.
    gather_index = sjoint_t[:, :, None]
    return np.take_along_axis(vecs, gather_index, axis=-2)
def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
    """Run bucketed causal attention and, optionally, its backward pass.

    Queries and keys/values are sorted by bucket id, attention is evaluated
    independently inside each of `self.n_bins` chunks, and the result is
    permuted back to the input order.

    Args:
      inputs: a `(q, k, v)` triple; q/k/v are n_batch, n_heads, seqlen, d_head.
        The length of `k` along axis 2 must be a multiple of `self.n_bins`
        (asserted).
      ct: cotangent of the output, or `None` to run the forward pass only.
      params: accepted and discarded.
      **kwargs: accepted and discarded.

    Returns:
      `(out, None)` when `ct` is `None`, otherwise
      `(out, (q_ct, k_ct, v_ct))`.
    """
    del params, kwargs
    q, k, v = inputs
    assert k.shape[2] % self.n_bins == 0
    bin_size = int(k.shape[2] // self.n_bins)

    # q_bins/kv_bins are n_batch, n_heads, seqlen: the bucket of each vector.
    # For now, instead of hashing, consecutive items share a bucket. The ids
    # are then negated, and the key/value ids are additionally doubled.
    seq_bins = jax.lax.tie_in(
        q, np.arange(q.shape[2], dtype=np.int32) // bin_size)
    q_bins = -np.broadcast_to(seq_bins[None, None, :], q.shape[:-1])
    kv_bins = q_bins * 2

    # q_t/kv_t are n_batch, n_heads, seqlen: the time index of each vector.
    positions = jax.lax.tie_in(q, np.arange(q.shape[2]))
    q_t = np.broadcast_to(
        np.reshape(positions, (1, 1, positions.shape[0])), q.shape[:-1])
    kv_t = q_t

    def chunk_rank3(x):
        return np.reshape(x, (x.shape[0], x.shape[1], self.n_bins, -1))

    def chunk_rank4(x):
        return np.reshape(
            x, (x.shape[0], x.shape[1], self.n_bins, -1, x.shape[-1]))

    def unchunk_rank4(x):
        return np.reshape(x, (x.shape[0], x.shape[1], -1, x.shape[-1]))

    def gather_seq(vecs, order):
        # Reorders the sequence axis of a rank-4 array with a rank-3 index.
        return np.take_along_axis(vecs, order[:, :, :, None], axis=2)

    # Sort everything by bin number (variables starting with "s" are sorted).
    _, sq_t = jax.lax.sort_key_val(q_bins, q_t, dimension=2)
    sq = gather_seq(q, sq_t)

    _, skv_t = jax.lax.sort_key_val(kv_bins, kv_t, dimension=2)
    sk = gather_seq(k, skv_t)
    sv = gather_seq(v, skv_t)

    @jax.jit
    def binned_attn(sorted_q, sorted_k, sorted_v):
        """Performs attention on sorted queries/keys/values."""
        # Split off a "bin" axis so that attention only occurs within chunks.
        query_t = chunk_rank3(sq_t)
        key_t = chunk_rank3(skv_t)
        queries = chunk_rank4(sorted_q)
        keys = chunk_rank4(sorted_k)
        values = chunk_rank4(sorted_v)

        scale = np.sqrt(queries.shape[-1])
        dots = np.matmul(queries, np.swapaxes(keys, -1, -2)) / scale

        # Causal masking: penalize keys that are later in time than the query.
        is_future = jax.lax.lt(
            query_t[:, :, :, :, None], key_t[:, :, :, None, :])
        dots = dots - 1e9 * jax.lax.convert_element_type(
            is_future, np.float32)

        # Softmax, shifted by the row maximum before exponentiating.
        weights = np.exp(dots - dots.max(axis=-1, keepdims=True))
        weights = weights / weights.sum(axis=-1, keepdims=True)
        return unchunk_rank4(np.matmul(weights, values))

    @jax.jit
    def binned_attn_vjp(sorted_q, sorted_k, sorted_v, sorted_out_ct):
        sorted_out, pullback = jax.vjp(
            binned_attn, sorted_q, sorted_k, sorted_v)
        return sorted_out, pullback(sorted_out_ct)

    if ct is None:
        so = binned_attn(sq, sk, sv)
        sorted_cts = None
    else:
        # Jax can construct a backward pass automatically, but it's about 2x
        # slower than writing our own. The backward pass of gather is in
        # general a scatter, but these gathers are permutations, so the
        # backward pass can be expressed with gathers as well.
        so_ct = gather_seq(ct, sq_t)
        so, sorted_cts = binned_attn_vjp(sq, sk, sv, so_ct)

    # Sorting the time indices by the sorted order gives the inverse
    # permutation, which restores the original sequence order.
    _, undo_q_sort = jax.lax.sort_key_val(sq_t, q_t, dimension=2)
    out = gather_seq(so, undo_q_sort)
    if sorted_cts is None:
        return out, None

    sq_ct, sk_ct, sv_ct = sorted_cts
    q_ct = gather_seq(sq_ct, undo_q_sort)
    _, undo_kv_sort = jax.lax.sort_key_val(skv_t, kv_t, dimension=2)
    k_ct = gather_seq(sk_ct, undo_kv_sort)
    v_ct = gather_seq(sv_ct, undo_kv_sort)
    return out, (q_ct, k_ct, v_ct)