def make_mask(N, M, k):  # pylint: disable=invalid-name
  """Builds an N x M float32 slice of the causal attention mask.

  NOTE(review): this module defines `make_mask` a second time further down;
  at import time that later definition replaces this one.

  Args:
    N: number of query positions.
    M: number of key positions.
    k: position of the first query; query row i sits at position i + k.

  Returns:
    float32 array of shape (N, M) holding 1.0 where key index j > i + k
    (attention not allowed) and 0.0 elsewhere.
  """
  query_pos = jax.lax.broadcast_in_dim(
      np.arange(N, dtype=np.int32), shape=(N, M), broadcast_dimensions=(0,))
  key_pos = jax.lax.broadcast(np.arange(M, dtype=np.int32), [N])
  in_future = jax.lax.lt(query_pos + k, key_pos)
  return jax.lax.convert_element_type(in_future, np.float32)
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Expands an n-dim integer-category array into an (n+1)-dim one-hot array.

  NOTE(review): this module defines `one_hot` again further down; at import
  time that later definition replaces this one.

  Args:
    x: integer array of category ids.
    size: number of categories; becomes the length of the new last axis.
    dtype: dtype of the returned array.

  Returns:
    Array of shape x.shape + (size,) with 1 where x[...] equals the index on
    the last axis and 0 elsewhere, cast to `dtype`.
  """
  class_ids = np.arange(size)
  if backend.get_name() == 'jax':
    # Work around a jax broadcasting issue by tying the constant range to x.
    class_ids = jax.lax.tie_in(x, class_ids)
  matches = x[..., np.newaxis] == class_ids
  return np.array(matches, dtype)
def make_mask(N, M, k):  # pylint: disable=invalid-name
  """Constructs a slice of the causal attention mask.

  Args:
    N: number of query positions
    M: number of key positions
    k: position of the initial query element

  Returns:
    N x M float32 mask, where 1.0 indicates that attention is not allowed
    (key index j is greater than query position i + k) and 0.0 otherwise.
  """
  rows = np.arange(N, dtype=np.int32)
  cols = np.arange(M, dtype=np.int32)
  row_grid = jax.lax.broadcast_in_dim(
      rows, shape=(N, M), broadcast_dimensions=(0,))
  col_grid = jax.lax.broadcast(cols, [N])
  blocked = jax.lax.lt(row_grid + k, col_grid)
  return jax.lax.convert_element_type(blocked, np.float32)
def bin_vectors_by_time(self, vecs):
  """Assigns every position along axis -2 of `vecs` to a contiguous time bin.

  Args:
    vecs: array whose second-to-last axis is the sequence axis; its length
      must be divisible by self.n_bins (asserted).

  Returns:
    int32 array of shape vecs.shape[:-1]; position p on the sequence axis is
    labelled p // (seqlen // self.n_bins), broadcast over the leading axes.
  """
  seqlen = vecs.shape[-2]
  assert seqlen % self.n_bins == 0
  width = int(seqlen // self.n_bins)
  bin_ids = jax.lax.tie_in(vecs, np.arange(seqlen, dtype=np.int32) // width)
  return np.broadcast_to(bin_ids[None, :], vecs.shape[:-1])
def make_self_mask(N, M, k):  # pylint: disable=invalid-name
  """Masks out elements attending to self.

  Args:
    N: number of query positions
    M: number of key positions
    k: position of the initial query element

  Returns:
    N x M float32 mask holding 1.0 exactly where key index j == i + k (query
    row i looking at its own position) and 0.0 elsewhere; 1.0 indicates that
    attention is not allowed.
  """
  query_idx = jax.lax.tie_in(k, np.arange(N, dtype=np.int32))
  key_idx = jax.lax.tie_in(k, np.arange(M, dtype=np.int32))
  query_grid = jax.lax.broadcast_in_dim(
      query_idx, shape=(N, M), broadcast_dimensions=(0,)) + k
  key_grid = jax.lax.broadcast(key_idx, [N])
  is_self = jax.lax.eq(query_grid, key_grid)
  return jax.lax.convert_element_type(is_self, np.float32)
def label_smoothed_loss(logpred, target, size, padding_idx=0, smoothing=0.0):
  """Computes a label-smoothed KL-divergence loss value.

  The smoothed target distribution gives `1 - smoothing` to the target class
  and spreads `smoothing` evenly over the other `size - 2` classes (all but
  the target and the padding class). It then zeroes the padding column and
  zeroes whole rows whose target equals `padding_idx`.

  Args:
    logpred: predicted log-probabilities; axis 1 must have length `size`
      (asserted). Presumably shape (batch, size) -- TODO confirm.
    target: integer class ids, one per row of `logpred`.
    size: number of classes.
    padding_idx: class id treated as padding.
    smoothing: probability mass moved off the target class.

  Returns:
    The result of `kl_div(logpred, <smoothed distribution>, eps=1e-6)`.

  Raises:
    ZeroDivisionError: when called with Python numbers and `size == 2`.
    AssertionError: if `logpred.shape[1] != size`.
  """
  confidence = 1.0 - smoothing
  off_value = smoothing / (size - 2)
  boost = confidence - off_value
  assert logpred.shape[1] == size
  smoothed = (np.full_like(logpred, off_value) +
              boost * slax.one_hot(target, size))
  smoothed *= (1 - (np.arange(size) == padding_idx))  # drop padding class
  smoothed *= (1 - (target == padding_idx))[:, np.newaxis]  # drop pad rows
  return kl_div(logpred, smoothed, eps=1e-6)
def hash_vectors(self, vecs, rng):
  """Assigns each row of `vecs` to hash buckets using random rotations.

  Args:
    vecs: array to hash. The einsum 'tf,fhb->htb' below treats the output of
      self.drop_for_hash(vecs, ...) as 2-D (seqlen, d); presumably
      drop_for_hash preserves the shape of vecs -- TODO confirm.
    rng: PRNG key. It is split into one key for sampling the random rotations
      and one key passed to self.drop_for_hash.

  Returns:
    1-D integer array of length n_hashes * seqlen, laid out round-major.
    With self._rehash_each_round, each round samples its own rotation, and
    round h's bucket ids are offset by h * self.n_buckets so ids from
    different rounds never collide. Otherwise one rotation is sampled, and
    each item gets its self.n_hashes highest-scoring buckets (ids in
    [0, self.n_buckets)), ordered from lowest to highest score across rounds.
  """
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  # n_buckets must be even: half the bucket scores come from the rotations,
  # the other half from their negations (see the concatenate below).
  assert self.n_buckets % 2 == 0
  random_rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      self.n_buckets // 2)

  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = jax.random.normal(
      rng, random_rotations_shape).astype('float32')

  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  # rotated_vecs is (rounds, seqlen, n_buckets // 2) here.
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
  # Appending the negations gives n_buckets scores per item and round.
  rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)

  if self._rehash_each_round:
    buckets = np.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1,))
  else:
    # In this configuration, we map each item to the top self.n_hashes buckets
    # Only one rotation was sampled, so drop the size-1 rounds axis.
    rotated_vecs = np.squeeze(rotated_vecs, 0)
    bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
    bucket_range = np.reshape(bucket_range, (1, -1))
    bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

    # Sort bucket ids by ascending score; the last n_hashes are the best ones.
    _, buckets = jax.lax.sort_key_val(
        rotated_vecs, bucket_range, dimension=-1)
    buckets = buckets[:, -self.n_hashes:]
    # (seqlen, n_hashes) -> (n_hashes, seqlen) -> flat, round-major.
    buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

  return buckets
def test_batch_norm(self):
  """BatchNorm: initial state plus one training step on a fixed input."""
  shape = (2, 3, 4)
  dtype = np.float32
  eps = 1e-5
  rng = backend.random.get_prng(0)
  # Deterministic input 0, 1, ..., 23 (not random), reshaped to `shape`.
  x = np.reshape(np.arange(np.prod(shape), dtype=dtype), shape)
  batch_mean = 11.5  # Mean of 0..23.
  batch_var = 47.9167  # Population variance of 0..23, i.e. 575 / 12.

  layer = normalization.BatchNorm(axis=(0, 1, 2))
  params, state = layer.initialize(shape, dtype, rng)
  # Fresh state: presumably (running mean, running variance, step count).
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 1)
  self.assertEqual(state[2], 0)

  out, state = layer(x, params, state)
  # Running statistics move 0.001 of the way toward the batch statistics.
  onp.testing.assert_allclose(state[0], batch_mean * 0.001)
  onp.testing.assert_allclose(state[1], 0.999 + batch_var * 0.001, rtol=1e-6)
  self.assertEqual(state[2], 1)
  # The output is normalized with the statistics of the current batch.
  expected = (x - batch_mean) / np.sqrt(batch_var + eps)
  onp.testing.assert_allclose(out, expected, rtol=1e-6)
def PreparePairedSequenceBatch(source, target_in, pad=0):
  """Builds shifted targets and attention masks for one paired batch.

  Args:
    source: (batch, source_len) array of integer-coded symbols for inputs
    target_in: (batch, target_len) array of integer-coded symbols for targets
    pad: int: the padding symbol used to pad the above

  Returns:
    Tuple (source, target, target_y, source_mask, target_mask, memory_mask,
    ntokens):
      target / target_y: target_in without its last / first column.
      source_mask: non-pad indicator for source, shape
        (batch, 1, 1, source_len).
      target_mask: result of MakeTargetMask(target, pad).
      memory_mask: boolean column of shape (target_len - 1, 1), True where
        the position index is smaller than source_len.
      ntokens: number of non-pad symbols in target_y.
  """
  target, target_y = target_in[:, :-1], target_in[:, 1:]
  batch_size, source_len = source.shape[0], source.shape[-1]
  source_mask = np.reshape(source != pad, (batch_size, 1, 1, source_len))
  target_mask = MakeTargetMask(target, pad)
  positions = np.arange(target.shape[-1])
  memory_mask = np.reshape(positions < source_len, [-1, 1])
  ntokens = np.sum(target_y != pad)
  return (source, target, target_y,
          source_mask, target_mask, memory_mask, ntokens)
def test_batch_norm(self):
  """BatchNorm: two training steps, then an eval-mode call that keeps state."""
  shape = (2, 3, 4)
  dtype = np.float32
  eps = 1e-5
  rng = backend.random.get_prng(0)

  def normalized(x, mean, var):
    return (x - mean) / np.sqrt(var + eps)

  # Deterministic input 0..23 with known mean and population variance.
  x1 = np.reshape(np.arange(np.prod(shape), dtype=dtype), shape)
  mean1 = 11.5
  var1 = 47.9167

  layer = normalization.BatchNorm(axis=(0, 1, 2))
  params, state = layer.initialize(shape, dtype, rng)
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 0)
  self.assertEqual(state[2], 0)

  # Step 1: the running statistics equal the first batch's statistics.
  out, state = layer(x1, params, state)
  onp.testing.assert_allclose(state[0], mean1)
  onp.testing.assert_allclose(state[1], var1, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, normalized(x1, mean1, var1), rtol=1e-6)

  # Step 2: an affine transform of the input; the running stats become the
  # plain average of the two batches' statistics.
  x2 = x1 * 2 + 3
  mean2 = mean1 * 2 + 3
  var2 = var1 * 4
  mean12 = (mean1 + mean2) / 2
  var12 = (var1 + var2) / 2
  out, state = layer(x2, params, state)
  onp.testing.assert_allclose(state[0], mean12)
  onp.testing.assert_allclose(state[1], var12, rtol=1e-6)
  self.assertEqual(state[2], 2)
  onp.testing.assert_allclose(out, normalized(x2, mean2, var2), rtol=1e-6)

  # Eval mode: state is left untouched and the running stats are used.
  eval_layer = normalization.BatchNorm(axis=(0, 1, 2), mode="eval")
  x3 = x1 * 5 + 7
  out, frozen_state = eval_layer(x3, params, state)
  for idx in range(3):
    onp.testing.assert_allclose(frozen_state[idx], state[idx])
  onp.testing.assert_allclose(out, normalized(x3, mean12, var12), rtol=1e-6)
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  """Multi-round hashed (LSH-style) causal self-attention, forward pass.

  Queries and keys share one tensor. Items are sorted by (hash bin, time),
  split into self.n_bins chunks, and each chunk attends within itself and
  one chunk back. Per-round results are then merged, weighted by each
  round's softmax normalizer.

  Args:
    inputs: triple (qk, k, v); the middle element is ignored. Indexing below
      implies qk and v are rank 3, presumably (n_batch*n_heads, seqlen,
      d_head) -- TODO confirm.
    params: ignored.
    state: returned unchanged.
    rng: passed to self.hash_vectors.
    **kwargs: ignored.

  Returns:
    (out, state), where out has the same shape as inputs[2] (asserted).
  """
  del params, kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, _, v = inputs
  seqlen = qk.shape[-2]

  # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
  # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
  qk = np.tile(qk, (self.n_hashes, 1, 1))
  v = np.tile(v, (self.n_hashes, 1, 1))

  # bins are n_hashes*n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_hashes*n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  assert int(
      (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
      ) < 2**31, (
          'Potential 32-bit integer overflow; please double-check the code.')
  # Combined sort key: bin id first, time index as the tie-breaker.
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted)
  _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t, joint_t, dimension=-1)
  # Sorting the permutation against the identity yields its inverse.
  _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
  # TODO(kitaev): why does jax flag integer indices as differentiable?
  # If we don't call stop_gradient here, custom gradients below won't work
  # because the primitive functions close over "differentiable" variables.
  sjoint_t = jax.lax.stop_gradient(sjoint_t)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  # The backward pass of gather is in general a scatter operation, but we know
  # we're dealing with permutations so we use gather for the backward pass
  # too. This custom gradient should be about 2x faster than having jax infer
  # one that uses scatter ops instead.
  # NOTE(review): jax.custom_transforms / jax.defvjp_all are legacy JAX APIs;
  # confirm they exist in the pinned JAX version (newer code uses custom_vjp).
  def permute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

  def unpermute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

  @jax.custom_transforms
  def permute(vecs):
    return permute_impl(vecs)

  def permute_vjp(vecs):
    out_vecs = permute_impl(vecs)
    def vjpfun(grad):
      return (unpermute_impl(grad), )
    return out_vecs, vjpfun

  @jax.custom_transforms
  def unpermute(vecs):
    return unpermute_impl(vecs)

  def unpermute_vjp(vecs):
    out_vecs = unpermute_impl(vecs)
    def vjpfun(grad):
      return (permute_impl(grad), )
    return out_vecs, vjpfun

  jax.defvjp_all(permute, permute_vjp)
  jax.defvjp_all(unpermute, unpermute_vjp)

  sqk = permute(qk)
  sv = permute(v)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = chunk_scalars(sjoint_t)
  bqk = chunk_vectors(sqk)
  bv = chunk_vectors(sv)

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bin, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  # The "extra" arrays are the previous chunk (wrapping around for chunk 0),
  # appended along the within-chunk axis.
  bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
  bk = np.concatenate([bk, bk_extra], axis=2)
  bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
  bv = np.concatenate([bv, bv_extra], axis=2)
  bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
  bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  # The diagonal over axes (2, 3) pairs each query with its own key in the
  # first (same-chunk) half of the key axis.
  self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
  self_mask = jax.lax.tie_in(dots, self_mask)
  dots = dots - 32 * self_mask

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._hard_k > 0:
    # NOTE(review): subtracting the k-th largest weight also zeroes that
    # weight itself, so at most k-1 weights stay positive (ties aside); if
    # every weight is zeroed, dots_sum is 0 and the division below gives NaN.
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Sum to re-normalize.
    dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
    dots /= dots_sum  # Re-normalize.

  bo = np.matmul(dots, bv)
  so = unchunk_vectors(bo)
  slogits = unchunk_vectors(dots_logsumexp)

  o = unpermute(so)
  logits = unpermute(slogits)

  # Merge the hashing rounds: weight each round's output by a softmax over
  # the rounds' log-normalizers.
  o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
  logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
  probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
  out = np.sum(o * probs, axis=0)
  assert out.shape == inputs[2].shape

  return out, state
def call_and_grad(self, inputs, ct, rng=None, **kwargs):
  """Hashed causal self-attention forward pass, optionally with its VJP.

  Args:
    inputs: triple (qk, k, v). The middle element only supplies the shape of
      its zero cotangent. Indexing below implies qk and v are rank 3,
      presumably (n_batch*n_heads, seqlen, d_head) -- TODO confirm.
    ct: cotangent of the output, or None for a forward-only call.
    rng: passed to self.hash_vectors.
    **kwargs: ignored.

  Returns:
    (out, None) if ct is None, else (out, (qk_ct, zeros_like(k), v_ct)).
  """
  del kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, ignored_k, v = inputs
  seqlen = qk.shape[-2]
  # qk/v are n_batch*n_heads, seqlen, d_head

  # bins are n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  # NOTE(review): this bound assumes bin ids are at most self.n_bins; confirm
  # against what this class's hash_vectors returns.
  assert int((self.n_bins + 1) * seqlen) < 2**31, (
      'Potential 32-bit integer overflow; please double-check the code.')
  # Combined sort key: bin id first, time index as the tie-breaker.
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted)
  _, sjoint_t = jax.lax.sort_key_val(joint_bins_and_t, joint_t, dimension=-1)
  sqk = np.take_along_axis(qk, sjoint_t[:, :, None], axis=-2)
  sv = np.take_along_axis(v, sjoint_t[:, :, None], axis=-2)

  if ct is not None:
    so_ct = np.take_along_axis(ct, sjoint_t[:, :, None], axis=-2)

  @jax.jit
  def binned_attn(sqk, sv):  # pylint: disable=invalid-name
    """Performs attention on sorted queries/keys/values."""
    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = bkv_t = chunk_scalars(sjoint_t)
    bqk = chunk_vectors(sqk)
    bv = chunk_vectors(sv)

    # Hashing operates on unit-length vectors. Unnormalized query vectors are
    # fine because they effectively provide a learnable temperature for the
    # attention softmax, but normalizing keys is needed so that similarity for
    # the purposes of attention correctly corresponds to hash locality.
    bq = bqk
    bk = self.make_unit_length(bqk)

    # Allow each chunk to attend within itself, and also one chunk back. Chunk
    # boundaries might occur in the middle of a sequence of items from the
    # same bin, so this increases the chances of attending to relevant items.
    # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
    bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
    bk = np.concatenate([bk, bk_extra], axis=2)
    bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
    bv = np.concatenate([bv, bv_extra], axis=2)
    bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
    bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

    # Dot-product attention.
    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
        bq.shape[-1])

    # Causal masking
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Mask out attention to self except when no other targets are available.
    self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
    self_mask = jax.lax.tie_in(dots, self_mask)
    dots = dots - 32 * self_mask

    # Softmax.
    dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))
    bo = np.matmul(dots, bv)

    so = unchunk_vectors(bo)
    return so

  @jax.jit
  def binned_attn_vjp(sqk, sv, so_ct):  # pylint: disable=invalid-name
    # Returns the sorted output and the cotangents w.r.t. (sqk, sv).
    so, vjpfun = jax.vjp(binned_attn, sqk, sv)
    sqkv_ct = vjpfun(so_ct)
    return so, sqkv_ct

  if ct is None:
    so = binned_attn(sqk, sv)
    # Sorting the permutation against the identity yields its inverse.
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)
    return out, None
  else:
    # Jax can construct a backward pass automatically, but it's about 2x
    # slower than writing our own. The main reason is that the backward pass
    # of gather is in general a scatter operation, but we know we're dealing
    # with permutations so we use gather for the backward pass too.
    so, (sqk_ct, sv_ct) = binned_attn_vjp(sqk, sv, so_ct)
    _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
    out = np.take_along_axis(so, undo_sort[:, :, None], axis=-2)

    qk_ct = np.take_along_axis(sqk_ct, undo_sort[:, :, None], axis=-2)
    v_ct = np.take_along_axis(sv_ct, undo_sort[:, :, None], axis=-2)

    return out, (qk_ct, np.zeros_like(ignored_k), v_ct)
def forward_and_vjp(self, inputs, ct, params=(), **kwargs):
  """Chunked causal attention over sorted bins, optionally with its VJP.

  Instead of hashing, consecutive items are put into the same bin.

  Args:
    inputs: triple (q, k, v). Indexing below implies rank 4, presumably
      (n_batch, n_heads, seqlen, d_head); k.shape[2] must be divisible by
      self.n_bins (asserted).
    ct: cotangent of the output, or None for a forward-only call.
    params: ignored.
    **kwargs: ignored.

  Returns:
    (out, None) if ct is None, else (out, (q_ct, k_ct, v_ct)).
  """
  del params, kwargs
  q, k, v = inputs
  # q/k/v are n_batch, n_heads, seqlen, d_head

  assert k.shape[2] % self.n_bins == 0
  bin_size = int(k.shape[2] // self.n_bins)

  # q_bins/kv_bins are n_batch, n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in. For
  # now, instead of hashing we just put consecutive items in the same bucket.
  # NOTE(review): bin_size comes from k's length but is applied to q's
  # positions; presumably q and k have equal seqlen -- TODO confirm.
  q_bins = np.arange(q.shape[2], dtype=np.int32) // bin_size
  q_bins = jax.lax.tie_in(q, q_bins)
  q_bins = q_bins[None, None, :]
  q_bins = np.broadcast_to(q_bins, q.shape[:-1])
  # NOTE(review): negating reverses the order in which bins come out of the
  # ascending sort, and kv_bins = 2 * q_bins orders items the same way as
  # q_bins. This looks like an exercise of the sort/unsort plumbing rather
  # than a real hashing scheme -- confirm intent.
  q_bins = -q_bins
  kv_bins = q_bins * 2

  # q_t/kv_t are n_batch, n_heads, seqlen
  q_t = jax.lax.tie_in(q, np.arange(q.shape[2]))
  q_t = np.reshape(q_t, (1, 1, q_t.shape[0]))
  q_t = np.broadcast_to(q_t, q.shape[:-1])
  kv_t = q_t

  def chunk_rank3(x):
    return np.reshape(x, (x.shape[0], x.shape[1], self.n_bins, -1))

  def chunk_rank4(x):
    return np.reshape(
        x, (x.shape[0], x.shape[1], self.n_bins, -1, x.shape[-1]))

  def unchunk_rank4(x):
    return np.reshape(x, (x.shape[0], x.shape[1], -1, x.shape[-1]))

  # Sort everything by bin number (variables starting with "s" are sorted)
  _, sq_t = jax.lax.sort_key_val(q_bins, q_t, dimension=2)

  sq = np.take_along_axis(q, sq_t[:, :, :, None], axis=2)
  if ct is not None:
    so_ct = np.take_along_axis(ct, sq_t[:, :, :, None], axis=2)

  _, skv_t = jax.lax.sort_key_val(kv_bins, kv_t, dimension=2)
  sk = np.take_along_axis(k, skv_t[:, :, :, None], axis=2)
  sv = np.take_along_axis(v, skv_t[:, :, :, None], axis=2)

  @jax.jit
  def binned_attn(sq, sk, sv):
    """Performs attention on sorted queries/keys/values."""
    # Split off a "bin" axis so that attention only occurs within chunks.
    bq_t = chunk_rank3(sq_t)
    bkv_t = chunk_rank3(skv_t)
    bq = chunk_rank4(sq)
    bk = chunk_rank4(sk)
    bv = chunk_rank4(sv)

    dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(
        bq.shape[-1])

    # Causal masking
    mask = jax.lax.convert_element_type(
        jax.lax.lt(bq_t[:, :, :, :, None], bkv_t[:, :, :, None, :]),
        np.float32)
    dots = dots - 1e9 * mask

    # Softmax.
    dots = np.exp(dots - dots.max(axis=-1, keepdims=True))
    dots = dots / dots.sum(axis=-1, keepdims=True)
    bo = np.matmul(dots, bv)

    so = unchunk_rank4(bo)
    return so

  @jax.jit
  def binned_attn_vjp(sq, sk, sv, so_ct):
    # Returns the sorted output and the cotangents w.r.t. (sq, sk, sv).
    so, vjpfun = jax.vjp(binned_attn, sq, sk, sv)
    sqkv_ct = vjpfun(so_ct)
    return so, sqkv_ct

  if ct is None:
    so = binned_attn(sq, sk, sv)
    # Sorting the permutation against the identity yields its inverse.
    _, undo_q_sort = jax.lax.sort_key_val(sq_t, q_t, dimension=2)
    out = np.take_along_axis(so, undo_q_sort[:, :, :, None], axis=2)
    return out, None
  else:
    # Jax can construct a backward pass automatically, but it's about 2x
    # slower than writing our own. The main reason is that the backward pass
    # of gather is in general a scatter operation, but we know we're dealing
    # with permutations so we use gather for the backward pass too.
    so, (sq_ct, sk_ct, sv_ct) = binned_attn_vjp(sq, sk, sv, so_ct)

    _, undo_q_sort = jax.lax.sort_key_val(sq_t, q_t, dimension=2)
    out = np.take_along_axis(so, undo_q_sort[:, :, :, None], axis=2)

    q_ct = np.take_along_axis(sq_ct, undo_q_sort[:, :, :, None], axis=2)

    _, undo_kv_sort = jax.lax.sort_key_val(skv_t, kv_t, dimension=2)
    k_ct = np.take_along_axis(sk_ct, undo_kv_sort[:, :, :, None], axis=2)
    v_ct = np.take_along_axis(sv_ct, undo_kv_sort[:, :, :, None], axis=2)

    return out, (q_ct, k_ct, v_ct)
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Expands an n-dim integer-category array into an (n+1)-dim one-hot array.

  Args:
    x: integer array of category ids.
    size: number of categories; becomes the length of the new last axis.
    dtype: dtype of the returned array.

  Returns:
    Array of shape x.shape + (size,) with 1 where x[...] equals the index on
    the last axis and 0 elsewhere, cast to `dtype`.
  """
  categories = np.arange(size)
  hits = x[..., np.newaxis] == categories
  return np.array(hits, dtype)
def _forward_train_eval(self, inputs, rng):
  """Chunked causal attention over consecutive-item bins (train/eval pass).

  The padded sequence is split into n_bins chunks of consecutive items, and
  each chunk attends to itself and to the previous chunk. Causality is
  enforced with time indices. When keys are shared with queries
  (self._share_qk), attention to self is also penalized.

  Args:
    inputs: passed to self._pad_inputs, which returns the padded (q, k, v)
      triple, the original length and the number of bins. Indexing below
      implies q/k/v are rank 3, presumably (n_batch*n_heads, seqlen, d_head)
      -- TODO confirm.
    rng: PRNG key used for attention dropout when self.dropout > 0.

  Returns:
    Attention output with the padded shape of v (asserted), truncated to
    `original_len` along the sequence axis (-2).
  """
  (inputs, original_len, n_bins) = self._pad_inputs(inputs)
  q, k, v = inputs
  seqlen = q.shape[-2]
  # q/k/v are n_batch*n_heads, seqlen, d_head
  # Time indices for causal masking.
  t = jax.lax.tie_in(q, np.arange(seqlen))

  # Split off a "bin" axis for chunks of consecutive items.
  bq_t = np.reshape(t, (n_bins, -1))
  bq = np.reshape(q, (q.shape[0], n_bins, -1, q.shape[-1]))
  if self._share_qk:
    bk = self.make_unit_length(bq)
  else:
    bk = np.reshape(k, (k.shape[0], n_bins, -1, k.shape[-1]))
  bv = np.reshape(v, (v.shape[0], n_bins, -1, v.shape[-1]))

  # Allow each chunk to attend within itself, and also one chunk back.
  def look_one_back(x):
    # Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
      return np.concatenate([x, x_extra], axis=1)
    else:
      assert len(x.shape) == 4
      x_extra = np.concatenate([x[:, -1:, :, :], x[:, :-1, :, :]], axis=1)
      return np.concatenate([x, x_extra], axis=2)

  bkv_t = look_one_back(bq_t)
  bk = look_one_back(bk)
  bv = look_one_back(bv)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking based on the time indices.
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[None, :, :, None], bkv_t[None, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  if self._share_qk:
    self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
    self_mask = jax.lax.tie_in(dots, self_mask)
    dots = dots - 1e5 * self_mask

  # Softmax.
  dots = np.exp(dots - backend.logsumexp(dots, axis=-1, keepdims=True))

  if self.dropout > 0.0:
    # Dropout is applied to the attention probabilities, after the softmax.
    # Multiplying the pre-softmax logits instead would turn dropped masked
    # entries (-1e9 causal, -1e5 self) into 0, un-masking them and letting
    # queries attend to future positions.
    # Dropout is broadcast across the batch+head dimension.
    dropout_shape = (1, dots.shape[-3], dots.shape[-2], dots.shape[-1])
    keep_prob = jax.lax.tie_in(dots, 1.0 - self.dropout)
    keep = backend.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  bo = np.matmul(dots, bv)

  output = np.reshape(bo, (bo.shape[0], -1, bo.shape[-1]))
  assert output.shape == v.shape
  return output[..., :original_len, :]
def single_call(self, qk, v, buckets, hash_rng=None):
  """Shared-QK hashed causal attention for one (unbatched) sequence.

  Items are sorted by (bucket, time) across all hashing rounds, split into
  self.n_hashes * self.n_bins chunks, and each chunk attends within itself
  and one chunk back. Per-round outputs are merged, weighted by a softmax
  over each round's log-normalizer.

  Args:
    qk: shared query/key vectors; seqlen is read from axis -2, and rows are
      gathered with np.take(..., axis=0), so presumably (seqlen, d_head) --
      TODO confirm.
    v: value vectors, gathered the same way as qk.
    buckets: 1-D bucket ids of length self.n_hashes * seqlen (asserted),
      presumably laid out round-major.
    hash_rng: unused in this body. NOTE(review): confirm it can be dropped.

  Returns:
    Output with the same shape as v (asserted).
  """
  # We use the same vector as both a query and a key.
  seqlen = qk.shape[-2]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
  # Combined sort key: bucket id first, time index as the tie-breaker.
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  # Sorting the permutation against the identity yields its inverse.
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sqk = np.take(qk, st, axis=0)
  sv = np.take(v, st, axis=0)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
  bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
  bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
  bq_buckets = bkv_buckets = np.reshape(
      sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bucket, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  def look_one_back(x):
    # Appends the previous chunk (wrapping around for chunk 0) along axis 1.
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
    else:
      x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
    return np.concatenate([x, x_extra], axis=1)

  bk = look_one_back(bk)
  bv = look_one_back(bv)
  bkv_t = look_one_back(bkv_t)
  bkv_buckets = look_one_back(bkv_buckets)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  self_mask = jax.lax.convert_element_type(
      jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
      np.float32)
  dots = dots - 1e5 * self_mask

  # Mask out attention to other hash buckets.
  if not self._attend_across_buckets:
    bucket_mask = jax.lax.convert_element_type(
        jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
        np.float32)
    dots = dots - 1e7 * bucket_mask

  # Don't double-count query-key pairs across multiple rounds of hashing.
  # There are two possible strategies here. (1) The default is to count how
  # many times a query-key pair is repeated, and to lower its log-prob
  # correspondingly at each repetition. (2) When hard_k is set, the code
  # instead masks all but the first occurrence of each query-key pair.
  # TODO(kitaev): is one strategy faster or more numerically stable?
  if not self._allow_duplicate_attention:
    # locs1: chunk index each item lands in per round; locs2: the next chunk
    # (which can also see it via look_one_back).
    locs1 = undo_sort // bq_t.shape[-1]
    locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
    if not self._attend_across_buckets:
      locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
      locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
    locs = np.moveaxis(np.concatenate([
        np.reshape(locs1, (self.n_hashes, seqlen)),
        np.reshape(locs2, (self.n_hashes, seqlen)),
    ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
    slocs = np.take(locs, st, axis=0)
    b_locs = np.reshape(
        slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
    # Queries always use the primary location (based on locs1).
    b_locs1 = b_locs[:, :, None, :self.n_hashes]
    if self._hard_k > 0:
      range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
      nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
      nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
      nouse_locs = np.reshape(
          np.broadcast_to(nouse_locs[:, None, :],
                          (self.n_hashes, self.n_bins, self.n_hashes)),
          (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
      b_locs1 = b_locs1 * nouse_locs
    bq_locs = np.broadcast_to(
        b_locs1,
        b_locs.shape[:2] + (2, self.n_hashes))
    bq_locs = np.reshape(bq_locs, b_locs.shape)
    bkv_locs = look_one_back(b_locs)

    # Number of location matches per query-key pair, same shape as dots.
    dup_counts = np.sum(
        jax.lax.convert_element_type(
            jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
            np.float32),
        axis=-1)
    assert dup_counts.shape == dots.shape
    if self._hard_k > 0:
      dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
    else:
      dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

  # Each query only attends to the top k most relevant keys.
  if self._hard_k > 0:
    b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
    b_top_dots = jax.lax.stop_gradient(b_top_dots)
    s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
    top_dots = np.take(s_top_dots, undo_sort, axis=0)

    merged_top_dots = np.moveaxis(
        np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
    merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

    dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
    # It's possible to compute the partition function at this point, but right
    # now this codepath isn't set up for backprop, and there might also be
    # issues computing it this way if two dot-products are exactly equal.

    sdots_thresh = dots_thresh[st]
    bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
    bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

    top_k_mask = jax.lax.convert_element_type(
        dots < bdots_thresh[..., None], np.float32)
    dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  bo = np.matmul(dots, bv)
  so = np.reshape(bo, (-1, bo.shape[-1]))
  slogits = np.reshape(dots_logsumexp, (-1,))

  def unsort_for_output_impl(so, slogits):
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)
    def vjpfun(o_logits_grads):
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)
    return (o, logits), vjpfun

  # NOTE(review): unsort_for_output (with its custom VJP) is built here, but
  # the line after it calls unsort_for_output_impl directly, so the custom
  # gradient is never used. Confirm whether that is intentional.
  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  o, logits = unsort_for_output_impl(so, slogits)

  if self.n_hashes == 1:
    out = o
  else:
    # Merge rounds: weight each round's output by a softmax over the rounds'
    # log-normalizers.
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
    out = np.sum(o * probs, axis=0)

  assert out.shape == v.shape
  return out