def forward(self, x, weights):
  """Look up rows of the weight table for each integer index in `x`.

  Args:
    x: array of integer indices (any shape).
    weights: table whose rows (axis 0) are gathered.

  Returns:
    Array of gathered rows, shape `x.shape + weights.shape[1:]`.
  """
  gathered = np.take(weights, x, axis=0)
  return gathered
def forward_unbatched(self, x, *, weights, state, update_state):
  """LSH attention forward pass for a single (unbatched) example.

  Args:
    x: activations; assumed shape (seqlen, d_model) — TODO confirm with caller.
    weights: tuple (w_q, w_v, w_o) of projection matrices.
    state: (buckets, rng) pair carried across calls.
    update_state: if True, fold the RNG and recompute hash buckets; otherwise
      reuse the buckets stored in `state`.

  Returns:
    (out, state): output projection of the attention result, and the
    (possibly updated) (buckets, rng) state.
  """
  w_q, w_v, w_o = weights
  q = np.matmul(x, w_q)
  v = np.matmul(x, w_v)
  if update_state:
    _, old_rng = state
    # Fold-in constants derive distinct, deterministic sub-RNGs from old_rng.
    rng = jax.random.fold_in(old_rng, 0)
    hash_rng = jax.random.fold_in(rng, 1)
    buckets = self.hash_vectors(q, hash_rng)
    state = (buckets, rng)
  else:
    buckets, rng = state
    rng = jax.random.fold_in(rng, 2)
  seqlen = x.shape[0]
  # hash_vectors is expected to emit n_hashes buckets per position, flattened.
  assert int(buckets.shape[0]) == self.n_hashes * seqlen
  # tie_in forces the arange to depend on x so XLA schedules it correctly.
  ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
  # Combine bucket id (major) with position (minor) into one sortable key.
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t, ticker, dimension=-1)
  # Sorting the sorted positions back against the ticker yields the inverse
  # permutation used to restore the original order after attention.
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)
  # Map sorted slots back to original sequence positions, then gather q/v.
  st = (sticker % seqlen)
  sq = np.take(q, st, axis=0)
  sv = np.take(v, st, axis=0)
  # Shared-QK attention: a position must not attend to itself, hence
  # exclude_self=True.
  mask_fn = functools.partial(mask_self_attention, causal=self.causal, exclude_self=True)
  q_info = st
  so, slogits = attend(
      sq, k=None, v=sv,
      q_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info,
      dropout=self.attention_dropout, rng=rng,
  )
  # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
  # also work, but these helpers include performance optimizations for TPU.
  o = permute_via_gather(so, undo_sort, sticker, axis=0)
  logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)
  if self.n_hashes > 1:
    # Combine the n_hashes rounds with a softmax over their logits.
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = np.sum(o * probs, axis=0)
  assert o.shape == (seqlen, w_v.shape[-1])
  out = np.matmul(o, w_o)
  return out, state
def permute_impl(val):
  """Gather rows of `val` along `axis` using the captured `permutation`."""
  permuted = np.take(val, permutation, axis=axis)
  return permuted
def vjpfun(permuted_grad):
  # JAX autodiff would synthesize a scatter operation here because it doesn't
  # know the indices form a permutation. On TPU, however, gathers are faster
  # than scatters (at least in the regime the LSH attention uses), so we undo
  # the permutation explicitly by gathering with its inverse.
  unpermuted_grad = np.take(permuted_grad, inverse_permutation, axis=axis)
  return (unpermuted_grad,)
def forward_unbatched(self, x, *, weights, state, update_state):
  """LSH attention forward pass for a single (unbatched) example.

  This variant un-sorts the attention output inline with a hand-written
  custom gradient (`unsort_for_output_vjp`) so the backward pass uses a
  gather instead of the scatter JAX autodiff would otherwise synthesize.

  Args:
    x: activations; assumed shape (seqlen, d_model) — TODO confirm with caller.
    weights: tuple (w_q, w_v, w_o) of projection matrices.
    state: (buckets, rng) pair carried across calls.
    update_state: if True, fold the RNG and recompute hash buckets; otherwise
      reuse the buckets stored in `state`.

  Returns:
    (out, state): output projection of the attention result, and the
    (possibly updated) (buckets, rng) state.
  """
  w_q, w_v, w_o = weights
  q = np.matmul(x, w_q)
  v = np.matmul(x, w_v)
  if update_state:
    _, old_rng = state
    # Fold-in constants derive distinct, deterministic sub-RNGs from old_rng.
    rng = jax.random.fold_in(old_rng, 0)
    hash_rng = jax.random.fold_in(rng, 1)
    buckets = self.hash_vectors(q, hash_rng)
    state = (buckets, rng)
  else:
    buckets, rng = state
    rng = jax.random.fold_in(rng, 2)
  seqlen = x.shape[0]
  # hash_vectors is expected to emit n_hashes buckets per position, flattened.
  assert int(buckets.shape[0]) == self.n_hashes * seqlen
  # tie_in forces the arange to depend on x so XLA schedules it correctly.
  ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
  # Combine bucket id (major) with position (minor) into one sortable key.
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)
  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(buckets_and_t, ticker, dimension=-1)
  # Sorting the sorted positions back against the ticker yields the inverse
  # permutation used to restore the original order after attention.
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)
  # Map sorted slots back to original sequence positions, then gather q/v.
  st = (sticker % seqlen)
  sq = np.take(q, st, axis=0)
  sv = np.take(v, st, axis=0)
  # Shared-QK attention: a position must not attend to itself, hence
  # exclude_self=True.
  mask_fn = functools.partial(mask_self_attention, causal=self.causal, exclude_self=True)
  q_info = st
  so, slogits = attend(
      sq, k=None, v=sv,
      q_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info,
      dropout=self.attention_dropout, rng=rng,
  )

  def unsort_for_output_impl(so, slogits):
    # Restore original order of the outputs via the inverse permutation.
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)
    def vjpfun(o_logits_grads):
      # Gather (faster than the scatter autodiff would emit on TPU).
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)
    return (o, logits), vjpfun

  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  # Bug fix: call the custom_transforms-wrapped function, not the bare impl.
  # Calling unsort_for_output_impl directly bypassed defvjp_all, so the
  # custom VJP above (and its TPU gather optimization) never took effect.
  o, logits = unsort_for_output(so, slogits)

  if self.n_hashes > 1:
    # Combine the n_hashes rounds with a softmax over their logits.
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = np.sum(o * probs, axis=0)
  assert o.shape == (seqlen, w_v.shape[-1])
  out = np.matmul(o, w_o)
  return out, state