def SumLearnedPick(positions):
  """Get a pair (vec, pos) and pick new pos.

  Builds five lookup branches over the rows of `positions` and mixes their
  outputs with a 5-way `SoftmaxBranches`:
    0. `LearnedQP()` constructed without a key/value table;
    1. successor:   pos[i]           -> pos[i + 1];
    2. predecessor: pos[i + 1]       -> pos[i];
    3. addition:    (pos[i], pos[j]) -> pos[i + j];
    4. subtraction: (pos[i], pos[j]) -> pos[max(i - j, 0)].
  For branches 3 and 4, i and j range over the first half of the rows.

  Args:
    positions: table of position vectors, indexed as [row, feature].

  Returns:
    A `tl.Serial` layer that copies its input to all branches and combines
    their results.
  """

  def pair_key(i, j):
    # Key for a binary lookup: the two position vectors side by side.
    return np.concatenate([positions[i, :], positions[j, :]])

  half = int(positions.shape[0]) // 2
  # The addition table iterates i in the outer loop and j in the inner loop.
  add_pairs = [(i, j) for i in range(half) for j in range(half)]
  # TODO(lukaszkaiser): try this below: "for j in range(i) for i in range(2*l)"
  # The subtraction table iterates j in the outer loop and i in the inner loop.
  sub_pairs = [(i, j) for j in range(half) for i in range(half)]

  add_keys = np.array([pair_key(i, j) for i, j in add_pairs])
  add_values = np.array([positions[i + j, :] for i, j in add_pairs])
  sub_keys = np.array([pair_key(i, j) for i, j in sub_pairs])
  sub_values = np.array([positions[max(i - j, 0), :] for i, j in sub_pairs])

  dups = [tl.Dup() for _ in range(4)]
  branches = tl.Parallel(
      LearnedQP(),
      LearnedQP(keys=positions[:-1, :], values=positions[1:, :]),
      LearnedQP(keys=positions[1:, :], values=positions[:-1, :]),
      LearnedQP(keys=add_keys, values=add_values, binary=True),
      LearnedQP(keys=sub_keys, values=sub_values, binary=True),
  )
  return tl.Serial(*dups, branches, Unnest(), SoftmaxBranches(n_branches=5))
def look_one_back(x):
  """Pair every entry with its cyclic predecessor along the chunk axis.

  Output: pairs [ bin_i bin_{i-1} ] concatenated on the time axis.

  * 2-D `x`: a copy rolled by one step along axis 0 (the last row wraps
    around to the front) is appended along axis 1.
  * 4-D `x`: a copy rolled by one step along axis 1 is appended along axis 2.

  Any other rank fails the assertion.
  """
  rank = len(x.shape)
  if rank == 2:
    return np.concatenate([x, np.roll(x, 1, axis=0)], axis=1)
  assert rank == 4
  return np.concatenate([x, np.roll(x, 1, axis=1)], axis=2)
def hash_vectors(self, vecs, rng):
  """Assign each vector a bucket id using random-rotation LSH.

  See https://arxiv.org/pdf/1509.02897.pdf

  Args:
    vecs: vectors to hash. The random rotation has one slice per entry of
      axis 0, so `vecs` is presumably (n_hashes*batch*heads, seqlen, d) --
      TODO confirm against callers.
    rng: PRNG key. It supplies the hashing dropout, and also the rotation
      unless `self._one_rng` is set (then `self._prng` is used instead).

  Returns:
    Integer bucket ids in [0, n_buckets_per_bin * n_bins), shaped like the
    matmul result below without its last axis.
  """
  if self.bin_by_time:
    # Instead of hashing, put chunks of consecutive items in the same bin.
    # This exists as a sanity check for the other parts of this class.
    return self.bin_vectors_by_time(vecs)

  # We sample a different random rotation for each batch element, head, and
  # (crucially) each round of hashing. All of these are part of dimension 0
  # of vecs. Applying multiple hashes to the same input is important because
  # it increases the probability of being in the same bin as relevant items.
  n_buckets = self.n_buckets_per_bin * self.n_bins
  assert n_buckets % 2 == 0

  # Split the key BEFORE any use. Previously `rng` was fed straight into
  # jax.random.normal and afterwards split to obtain the dropout key, so one
  # key drove two random streams; JAX keys must be consumed only once.
  rng, subrng = backend.random.split(rng)
  rot_rng = rng
  if self._one_rng:
    rot_rng = jax.lax.tie_in(vecs, self._prng)
  random_rotation = jax.random.normal(
      rot_rng,
      (vecs.shape[0], vecs.shape[-1], n_buckets // 2)).astype('float32')

  # TODO(kitaev): making the vectors unit-length here is probably redundant.
  # (Scaling a row of vecs by a positive constant scales its rotated row by
  # the same constant and leaves the argmax unchanged; it could only matter
  # through drop_for_hash, whose definition is not visible here.)
  vecs = self.drop_for_hash(vecs, subrng)
  rotated_vecs = np.matmul(vecs, random_rotation)
  # [r, -r] turns n_buckets // 2 projections into n_buckets candidates; each
  # vector is assigned to the candidate with the largest value.
  rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
  return np.argmax(rotated_vecs, axis=-1)
def NewPositionalEncoding(x, positions=None, **kwargs):
  """Implements new positional encoding.

  Appends a position vector to every step of `x`: step t receives row t of
  `positions`, and the same rows are used for every batch element.

  Args:
    x: input indexed as [batch, length, features].
    positions: table of position vectors, indexed as [position, pos_features];
      it must have at least `length` rows.
    **kwargs: unused.

  Returns:
    `x` with the position vectors concatenated on axis 2.

  Raises:
    ValueError: if `positions` is None.
  """
  del kwargs
  if positions is None:
    raise ValueError('NewPositionalEncoding requires a `positions` table.')
  batch_size, x_length = np.shape(x)[0], np.shape(x)[1]
  pos = np.array(positions)[np.newaxis, :x_length, :]
  # Tile across the batch. The previous in-place `pos += np.zeros((B, 1, 1))`
  # fails under NumPy for B > 1, because an in-place add cannot grow its
  # (1, L, d) output operand to (B, L, d).
  pos = np.broadcast_to(pos, (batch_size,) + tuple(np.shape(pos)[1:]))
  return np.concatenate([x, pos], axis=2)
def ConcatenateN(xs, params, n=2, axis=-1, **kwargs):
  """Join the first `n` inputs along `axis`; pass any remaining inputs through.

  Args:
    xs: sequence of inputs.
    params: unused.
    n: how many leading inputs to concatenate.
    axis: concatenation axis.
    **kwargs: unused.

  Returns:
    The concatenated array alone if `xs` has no more than `n` elements,
    otherwise a tuple of the concatenated array followed by the leftovers.
  """
  del params, kwargs
  joined = np.concatenate(xs[:n], axis)
  leftovers = tuple(xs[n:])
  return (joined,) + leftovers if leftovers else joined
def CopyHeadsPos(x, h=8, **unused_kwargs):
  """Mix x = (x, p) into x_h1, p_h1, x_h2, p_h2, ....

  Along axis 2, the last `h * POS_VECTOR_SIZE` features of `x` are `h`
  position vectors (one per head). The features before them are split into
  `h` head slices of equal width (any remainder from integer division is
  dropped). The output interleaves each head slice with its own position
  vector.
  """
  pos_width = h * POS_VECTOR_SIZE
  head_size = (x.shape[2] - pos_width) // h
  p = x[:, :, -pos_width:]
  pieces = []
  for head in range(h):
    start = head * head_size
    pos_start = head * POS_VECTOR_SIZE
    pieces.append(x[:, :, start:start + head_size])
    pieces.append(p[:, :, pos_start:pos_start + POS_VECTOR_SIZE])
  return np.concatenate(pieces, axis=-1)
def MixHeadsPos(x, h=8, **unused_kwargs):
  """Mix x = (x0, p) into x0_h1, p, x0_h2, p, ....

  Along axis 2, the last POS_VECTOR_SIZE features of `x` are a single
  position vector `p`. The features before it are split into `h` head slices
  of equal width. The output is every head slice followed by a copy of `p`.
  """
  head_size = (x.shape[2] - POS_VECTOR_SIZE) // h
  p = x[:, :, -POS_VECTOR_SIZE:]
  pieces = []
  for k in range(h):
    pieces.extend((x[:, :, k * head_size:(k + 1) * head_size], p))
  return np.concatenate(pieces, axis=-1)
def QueryPositionKV(x, keys=None, values=None, binary=False, **unused_kwargs):
  """Query a table with a position vector.

  Args:
    x: query vector(s).
    keys: table keys. When None, the layer is the identity and `x` is
      returned unchanged.
    values: table values, paired row-by-row with `keys`.
    binary: if True, the query is `x` concatenated with itself on the last
      axis, matching keys made of two position vectors side by side (as in
      SumLearnedPick's addition/subtraction tables).

  Returns:
    The output of `tl.DotProductAttention` over the (keys, values) table.
  """
  if keys is None:
    return x
  table_keys = np.array(keys)
  table_values = np.array(values)
  query = np.concatenate([x, x], axis=-1) if binary else x
  return tl.DotProductAttention(
      query, table_keys, table_values, None, None, None, None)
def ChunkedAttentionSelector(x, params, selector=None, **kwargs):
  """Select which chunks to attend to in chunked attention.

  Args:
    x: inputs, a sequence of elements of the form ((q, k, v), mask), one per
      chunk.
    params: parameters (unused).
    selector: a function from chunk_number -> list of chunk numbers that says
      which other chunks should be appended to the given one (the previous
      chunk if None).
    **kwargs: unused other arguments.

  Returns:
    A tuple of elements of the form (q, k', v', mask'). k' and v' are the
    keys/values of the selected chunks followed by the chunk's own, and mask'
    is all-ones masks for the selected chunks followed by the chunk's own
    mask. An empty `x` yields an empty tuple.
  """
  del params, kwargs
  if len(x) == 0:
    # zip(*x) below cannot unpack an empty sequence; nothing to select.
    return ()
  selector = selector or (lambda i: [] if i < 1 else [i - 1])
  triples, masks = zip(*x)
  (queries, keys, values) = zip(*triples)
  result = []
  for i in range(len(x)):
    selected = selector(i)
    # Since keys and values are [batch, length, depth] we concatenate on axis=1.
    # We also always include the current key or value at the end.
    new_key_list = [keys[j] for j in selected]
    new_key = np.concatenate(new_key_list + [keys[i]], axis=1)
    new_value = np.concatenate(
        [values[j] for j in selected] + [values[i]], axis=1)
    cur_mask = masks[i]
    # Masks are (1, query-len, key-len) so we concatenate on axis=2. Masks are
    # all-1 for the added chunks (no masking).
    new_mask_list = [
        np.ones((1, queries[i].shape[1], key.shape[1]), dtype=cur_mask.dtype)
        for key in new_key_list
    ]
    # We still use the current (often causal) mask for the final chunk.
    new_mask = np.concatenate(new_mask_list + [cur_mask], axis=2)
    result.append((queries[i], new_key, new_value, new_mask))
  return tuple(result)
def DiagonalGate(x, params, **kwargs):
  """Split channels in 3 parts and shift the outer parts along the length axis.

  With a zero row padded at each end of axis 2, output position i takes:
    * first third of channels  <- input position i - 1 (zeros at i = 0),
    * middle third of channels <- input position i,
    * last third of channels   <- input position i + 1 (zeros at the end).
  The shape is unchanged.
  """
  del params
  del kwargs
  # x : [batch, 1, length, depth]
  padded = np.pad(
      x, [(0, 0), (0, 0), (1, 1), (0, 0)], mode='constant',
      constant_values=0.0)
  third = padded.shape[-1] // 3
  assert 3 * third == padded.shape[-1], ('Depth must be divisible by 3',
                                         third, padded.shape)
  from_prev = padded[:, :, :-2, :third]
  same = padded[:, :, 1:-1, third:2 * third]
  from_next = padded[:, :, 2:, 2 * third:3 * third]
  return np.concatenate([from_prev, same, from_next], axis=3)
def hash_vectors(self, vecs, rng):
  """Assign each vector to one of `self.n_bins` bins by random-rotation LSH.

  See https://arxiv.org/pdf/1509.02897.pdf

  A single random matrix (features x n_bins // 2) projects the unit-length
  vectors. The bin is the argmax over [projection, -projection].

  When `self.bin_by_time` is set, hashing is skipped and
  `self.bin_vectors_by_time` decides the bins instead. That path exists as a
  sanity check: chunks of consecutive items share a bin.
  """
  if not self.bin_by_time:
    assert self.n_bins % 2 == 0
    rotation = jax.random.normal(
        rng, (vecs.shape[-1], self.n_bins // 2)).astype('float32')
    # TODO(kitaev): making the vectors unit-length here is probably redundant.
    unit_vecs = self.make_unit_length(vecs)
    projected = self.make_unit_length(np.matmul(unit_vecs, rotation))
    both_signs = np.concatenate([projected, -projected], axis=-1)
    return np.argmax(both_signs, axis=-1)
  return self.bin_vectors_by_time(vecs)
def ShiftRight(x, **unused_kwargs):
  """Layer to shift the tensor to the right by padding on axis 1.

  Single tensor: a zero is inserted at the start of axis 1 and the last
  element along axis 1 is dropped, so the shape is unchanged. Any rank >= 2
  is accepted; only axis 1 is padded. (Previously the pad spec had exactly
  two entries, so rank-3+ tensors were rejected by np.pad even though the
  chunked path below already handled them.)

  Chunked input (a list or tuple of tensors): the chunks are treated as
  consecutive pieces of one long sequence along axis 1, and that sequence is
  shifted. The result is a list of chunks with the original shapes. An empty
  chunk list yields an empty list.
  """
  if not isinstance(x, (list, tuple)):  # non-chunked inputs
    # One (before, after) pair per axis: pad axis 1 by one on the left only.
    pad_widths = [(0, 0), (1, 0)] + [(0, 0)] * (len(x.shape) - 2)
    padded = np.pad(x, pad_widths, mode='constant')
    return padded[:, :-1]
  if not x:
    # No chunks: nothing to shift (x[0] below would raise IndexError).
    return []
  # Handling chunked inputs. Recall that the list of chunks represents a big
  # sequence (the concatenation of the chunks). We want to shift that sequence,
  # so we put a 0 in the beginning of the first chunk and the last element of
  # that chunk is used as the new first element of the next chunk, and so on.
  padded = []
  last_value = np.zeros_like(x[0][:, -1])
  for chunk in x:
    padded_chunk = np.concatenate([last_value[:, np.newaxis], chunk], axis=1)
    last_value = chunk[:, -1]
    padded.append(padded_chunk[:, :-1])
  return padded
def binned_attn(sqk, sv):  # pylint: disable=invalid-name
  """Performs attention on sorted queries/keys/values.

  This is a closure. It reads `sjoint_t`, `chunk_scalars`, `chunk_vectors`,
  `unchunk_vectors` and `self` from the enclosing scope.
  """

  def with_previous_chunk(t):
    # Append, along axis 2, the contents of the previous chunk on axis 1
    # (wrapping around, so chunk 0 sees the last chunk).
    previous = np.concatenate([t[:, -1:], t[:, :-1]], axis=1)
    return np.concatenate([t, previous], axis=2)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = chunk_scalars(sjoint_t)
  bq = chunk_vectors(sqk)
  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  # Each chunk attends within itself and also one chunk back, because chunk
  # boundaries may split a run of items from the same bin.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  bk = with_previous_chunk(self.make_unit_length(bq))
  bv = with_previous_chunk(chunk_vectors(sv))
  bkv_t = with_previous_chunk(bq_t)

  # Scaled dot-product attention.
  logits = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])
  # Causal masking: penalize keys that come later in time than the query.
  future = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]), np.float32)
  logits = logits - 1e9 * future
  # Mask out attention to self except when no other targets are available.
  diagonal = jax.lax.tie_in(
      logits, jax.lax.broadcasted_eye(logits.dtype, logits.shape, (2, 3)))
  logits = logits - 32 * diagonal
  # Softmax.
  weights = np.exp(logits - backend.logsumexp(logits, axis=-1, keepdims=True))
  return unchunk_vectors(np.matmul(weights, bv))
def hash_vectors(self, vecs, rng):
  """Hash `vecs` into bucket ids for every hashing round (random-rotation LSH).

  Args:
    vecs: 2-D vectors indexed as [position, feature] (per the 'tf' operand of
      the einsum below).
    rng: PRNG key. It is split into one key for the rotations and one for
      `self.drop_for_hash`.

  Returns:
    A flat array of n_hashes * seqlen bucket ids, laid out round-major (all
    positions for round 0, then round 1, ...).
    * `self._rehash_each_round`: each round uses its own rotation. Ids of
      round r are offset by r * n_buckets, so ids from different rounds never
      collide.
    * otherwise: one rotation is used, and the "rounds" are the n_hashes
      highest-scoring buckets of each item (ids in [0, n_buckets), no offsets).
  """
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  assert self.n_buckets % 2 == 0
  random_rotations_shape = (
      vecs.shape[-1],
      self.n_hashes if self._rehash_each_round else 1,
      self.n_buckets // 2)

  rng = jax.lax.tie_in(vecs, rng)
  rng, subrng = backend.random.split(rng)
  random_rotations = jax.random.normal(
      rng, random_rotations_shape).astype('float32')
  # TODO(lukaszkaiser): the dropout mask will be used for all rounds of
  # hashing, so it's shared between them. Check if that's what we want.
  dropped_vecs = self.drop_for_hash(vecs, subrng)
  # (seqlen, d) x (d, rounds, n_buckets // 2) -> (rounds, seqlen, n_buckets // 2)
  rotated_vecs = np.einsum('tf,fhb->htb', dropped_vecs, random_rotations)
  # [r, -r] yields n_buckets candidate buckets per round.
  rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)

  if self._rehash_each_round:
    buckets = np.argmax(rotated_vecs, axis=-1)
    # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
    # bucket numbers from different hashing rounds don't overlap.
    offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
    offsets = np.reshape(offsets * self.n_buckets, (-1, 1))
    buckets = np.reshape(buckets + offsets, (-1,))
  else:
    # In this configuration, we map each item to the top self.n_hashes buckets
    # Drop the size-1 rounds axis: rotated_vecs becomes (seqlen, n_buckets).
    rotated_vecs = np.squeeze(rotated_vecs, 0)
    bucket_range = jax.lax.tie_in(vecs, np.arange(rotated_vecs.shape[-1]))
    bucket_range = np.reshape(bucket_range, (1, -1))
    bucket_range = np.broadcast_to(bucket_range, rotated_vecs.shape)

    # Sort bucket ids by their score (ascending) and keep the last n_hashes,
    # i.e. the highest-scoring ones; then lay them out round-major and flatten.
    _, buckets = jax.lax.sort_key_val(
        rotated_vecs, bucket_range, dimension=-1)
    buckets = buckets[:, -self.n_hashes:]
    buckets = np.reshape(np.moveaxis(buckets, 0, -1), (-1,))

  return buckets
def CombineHeadsPos(x, h=8, **unused_kwargs):
  """Mix x = (x0, p0, ..., xH, pH) into x0, ...., xH, p_combined.

  The positions are added as vectors.

  Args:
    x: input vector, concatenated (x0, p0, ..., xH, pH) along axis 2. Each
      head occupies x.shape[2] // h features: head content followed by
      POS_VECTOR_SIZE position features.
    h: number of heads.

  Returns:
    the vector with combined positions: the h head-content slices followed
    by the element-wise sum of the h position slices.
  """
  # Integer division, as in the other head-splitting helpers. The previous
  # int(x.shape[2] / h - POS_VECTOR_SIZE) went through a float and truncated
  # toward zero, which diverges from floor division for inconsistent widths.
  block_size = x.shape[2] // h
  head_size = block_size - POS_VECTOR_SIZE
  res, positions = [], []
  for head in range(h):
    start = head * block_size
    res.append(x[:, :, start:start + head_size])
    positions.append(x[:, :, start + head_size:start + block_size])
  res.append(sum(positions))
  return np.concatenate(res, axis=-1)
def hash_vectors(self, vecs, rng):
  """Assign each vector to one of `self.n_bins` bins by random-rotation LSH.

  See https://arxiv.org/pdf/1509.02897.pdf

  One random rotation is drawn per entry of axis 0 of `vecs` (shape
  (vecs.shape[0], features, n_bins // 2)). It is applied to the unit-length
  vectors, and the bin is the argmax over [projection, -projection]. It's not
  clear whether sampling a different random rotation for each head and batch
  element matters here, but see MergedMultiHashedCausalAttention.

  The rotation key is `rng`, or `self._prng` when `self._one_rng` is set.
  When `self.bin_by_time` is set, `self.bin_vectors_by_time` is used
  instead; it puts chunks of consecutive items in the same bin as a sanity
  check for the other parts of this class.
  """
  if self.bin_by_time:
    return self.bin_vectors_by_time(vecs)
  assert self.n_bins % 2 == 0
  key = jax.lax.tie_in(vecs, self._prng) if self._one_rng else rng
  rotation_shape = (vecs.shape[0], vecs.shape[-1], self.n_bins // 2)
  rotation = jax.random.normal(key, rotation_shape).astype('float32')
  # TODO(kitaev): making the vectors unit-length here is probably redundant.
  projected = self.make_unit_length(
      np.matmul(self.make_unit_length(vecs), rotation))
  return np.argmax(np.concatenate([projected, -projected], axis=-1), axis=-1)
def call(self, inputs, params=(), state=(), rng=None, **kwargs):
  """Multi-round hashed causal attention with a shared query/key vector.

  Every input is tiled n_hashes times along axis 0, one copy per hashing
  round. Items are sorted by (bin, time), and attention runs within
  fixed-size chunks plus one chunk back. The per-round outputs are then
  combined, weighted by each round's softmax normalizer (log-sum-exp).

  Args:
    inputs: a triple (qk, k, v). The middle element is ignored because qk is
      used as both query and key. qk and v appear to be rank 3
      (n_batch*n_heads, seqlen, d_head), per the tiling and the rank asserts
      in permute_impl -- TODO confirm.
    params: unused.
    state: returned unchanged.
    rng: forwarded to self.hash_vectors.
    **kwargs: unused.

  Returns:
    (out, state), where out has the same shape as inputs[2] (asserted).
  """
  del params, kwargs
  # We use the same vector as both a query and a key. For now we haven't
  # adjusted any of the surrounding code, so we still get a separate "key"
  # input that we ignore.
  qk, _, v = inputs
  seqlen = qk.shape[-2]

  # qk/v are n_hashes*n_batch*n_heads, seqlen, d_head
  # TODO(kitaev): is it faster to fuse this tiling into gather/scatter ops?
  qk = np.tile(qk, (self.n_hashes, 1, 1))
  v = np.tile(v, (self.n_hashes, 1, 1))

  # bins are n_hashes*n_batch*n_heads, seqlen
  # They specify which hash bucket the query/key/value vectors fall in.
  bins = self.hash_vectors(qk, rng=rng)

  # joint_t is n_hashes*n_batch*n_heads, seqlen
  joint_t = jax.lax.tie_in(qk, np.arange(seqlen))
  joint_t = np.reshape(joint_t, (1, seqlen))
  joint_t = np.broadcast_to(joint_t, qk.shape[:-1])

  # The combined sort key seqlen * bins + t must fit in int32.
  assert int(
      (self.n_buckets_per_bin * self.n_bins + 1) * seqlen
      ) < 2**31, (
          'Potential 32-bit integer overflow; please double-check the code.')
  joint_bins_and_t = seqlen * bins + joint_t

  def chunk_scalars(x):  # pylint: disable=invalid-name
    # (rows, seqlen) -> (rows, n_bins, chunk_len)
    return np.reshape(x, (x.shape[0], self.n_bins, -1))

  def chunk_vectors(x):  # pylint: disable=invalid-name
    # (rows, seqlen, d) -> (rows, n_bins, chunk_len, d)
    return np.reshape(x, (x.shape[0], self.n_bins, -1, x.shape[-1]))

  def unchunk_vectors(x):  # pylint: disable=invalid-name
    # (rows, n_bins, chunk_len, d) -> (rows, seqlen, d)
    return np.reshape(x, (x.shape[0], -1, x.shape[-1]))

  # Sort everything by bin number, with a secondary sort by time
  # (variables starting with "s" are sorted)
  _, sjoint_t = jax.lax.sort_key_val(
      joint_bins_and_t, joint_t, dimension=-1)
  # Sorting the sorted times back against joint_t gives the inverse
  # permutation, used to restore the original order afterwards.
  _, undo_sort = jax.lax.sort_key_val(sjoint_t, joint_t, dimension=-1)
  # TODO(kitaev): why does jax flag integer indices as differentiable?
  # If we don't call stop_gradient here, custom gradients below won't work
  # because the primitive functions close over "differentiable" variables.
  sjoint_t = jax.lax.stop_gradient(sjoint_t)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  # The backward pass of gather is in general a scatter operation, but we know
  # we're dealing with permutations so we use gather for the backward pass
  # too. This custom gradient should be about 2x faster than having jax infer
  # one that uses scatter ops instead.
  def permute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, sjoint_t[:, :, None], axis=-2)

  def unpermute_impl(vecs):
    assert len(vecs.shape) == 3
    return np.take_along_axis(vecs, undo_sort[:, :, None], axis=-2)

  @jax.custom_transforms
  def permute(vecs):
    return permute_impl(vecs)

  def permute_vjp(vecs):
    out_vecs = permute_impl(vecs)
    def vjpfun(grad):
      # The gradient of a permutation is the inverse permutation.
      return (unpermute_impl(grad), )
    return out_vecs, vjpfun

  @jax.custom_transforms
  def unpermute(vecs):
    return unpermute_impl(vecs)

  def unpermute_vjp(vecs):
    out_vecs = unpermute_impl(vecs)
    def vjpfun(grad):
      return (permute_impl(grad), )
    return out_vecs, vjpfun

  jax.defvjp_all(permute, permute_vjp)
  jax.defvjp_all(unpermute, unpermute_vjp)

  sqk = permute(qk)
  sv = permute(v)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = chunk_scalars(sjoint_t)
  bqk = chunk_vectors(sqk)
  bv = chunk_vectors(sv)

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bin, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  # Each *_extra is the previous chunk (cyclic on axis 1), appended on axis 2.
  bk_extra = np.concatenate([bk[:, -1:, :, :], bk[:, :-1, :, :]], axis=1)
  bk = np.concatenate([bk, bk_extra], axis=2)
  bv_extra = np.concatenate([bv[:, -1:, :, :], bv[:, :-1, :, :]], axis=1)
  bv = np.concatenate([bv, bv_extra], axis=2)
  bkv_t_extra = np.concatenate([bkv_t[:, -1:, :], bkv_t[:, :-1, :]], axis=1)
  bkv_t = np.concatenate([bkv_t, bkv_t_extra], axis=2)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  # (keys whose time is later than the query's get a -1e9 penalty)
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, :, None], bkv_t[:, :, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  # The eye over axes (2, 3) marks key slot i == query slot i, which lies in
  # the query's own chunk (the first half of the key axis).
  self_mask = jax.lax.broadcasted_eye(dots.dtype, dots.shape, (2, 3))
  self_mask = jax.lax.tie_in(dots, self_mask)
  dots = dots - 32 * self_mask

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if self._hard_k > 0:
    # NOTE(review): subtracting the k-th largest weight also zeroes that
    # weight, so only the k-1 largest stay non-zero -- confirm whether k or
    # k-1 is intended. If the top k weights are all equal, dots_sum is 0 and
    # the division / log below produce NaN / -inf.
    top_k = np.sort(dots)[..., -self._hard_k]  # Get the top-kth weight.
    top_k = jax.lax.stop_gradient(top_k)
    dots -= top_k[..., np.newaxis]  # Subtract (be 0 for lower ones).
    dots = np.maximum(dots, 0)
    dots_sum = np.sum(dots, axis=-1, keepdims=True)  # Sum to re-normalize.
    dots_logsumexp += np.log(dots_sum)  # Add it to the weight.
    dots /= dots_sum  # Re-normalize.

  bo = np.matmul(dots, bv)
  so = unchunk_vectors(bo)
  slogits = unchunk_vectors(dots_logsumexp)

  o = unpermute(so)
  logits = unpermute(slogits)

  # Combine the n_hashes rounds: weight each round's output by the softmax
  # (over rounds) of its log-normalizer.
  o = np.reshape(o, (self.n_hashes, -1, seqlen, o.shape[-1]))
  logits = np.reshape(logits, (self.n_hashes, -1, seqlen, 1))
  probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
  out = np.sum(o * probs, axis=0)
  assert out.shape == inputs[2].shape

  return out, state
def look_one_back(x):
  """Append, along axis 1, a copy of `x` rotated forward by one along axis 0.

  Entry i of the appended copy is entry i - 1 of `x` (entry 0 receives the
  last one), so every item is paired with its cyclic predecessor. `x` must be
  rank 2, or rank >= 3 for the explicit three-index slicing of the other
  branch.
  """
  if len(x.shape) == 2:
    last, leading = x[-1:, :], x[:-1, :]
  else:
    last, leading = x[-1:, :, :], x[:-1, :, :]
  predecessor = np.concatenate([last, leading], axis=0)
  return np.concatenate([x, predecessor], axis=1)
def apply_fun(params, inputs, **kwargs):
  """Concatenate `inputs` along `axis`.

  `axis` is not a parameter. It is read from the enclosing scope (presumably
  the layer constructor that defines this function -- confirm). `params` and
  `kwargs` exist only for the layer calling convention and are ignored.
  """
  del params, kwargs  # Unused.
  return np.concatenate(inputs, axis=axis)
def single_call(self, qk, v, buckets, hash_rng=None):
  """Hashed causal attention for one (qk, v) pair over all hashing rounds.

  Items are sorted by (bucket, position) separately for every round. Sorted
  items are split into n_hashes * n_bins equal chunks, and each query attends
  to its own chunk and the previous one. The per-round results are unsorted
  and combined, weighted by each round's log-sum-exp.

  Args:
    qk: shared query/key vectors. seqlen = qk.shape[-2] and rows are gathered
      along axis 0, so qk is presumably (seqlen, d) -- TODO confirm.
    v: value vectors. The output is asserted to have v's shape.
    buckets: flat bucket ids of length n_hashes * seqlen (asserted below),
      laid out round-major: entry r * seqlen + t is position t in round r.
    hash_rng: not used anywhere in this method.

  Returns:
    The attention output, with the same shape as `v`.
  """
  # We use the same vector as both a query and a key.
  seqlen = qk.shape[-2]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  # ticker enumerates (round, position) pairs; ticker % seqlen is the position.
  ticker = jax.lax.tie_in(qk, np.arange(self.n_hashes * seqlen))
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  # Inverse permutation of the sort, used to restore the original order.
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sqk = np.take(qk, st, axis=0)
  sv = np.take(v, st, axis=0)

  # Split off a "bin" axis so that attention only occurs within chunks.
  bq_t = bkv_t = np.reshape(st, (self.n_hashes * self.n_bins, -1))
  bqk = np.reshape(sqk, (self.n_hashes * self.n_bins, -1, sqk.shape[-1]))
  bv = np.reshape(sv, (self.n_hashes * self.n_bins, -1, sv.shape[-1]))
  bq_buckets = bkv_buckets = np.reshape(
      sbuckets_and_t // seqlen, (self.n_hashes * self.n_bins, -1))

  # Hashing operates on unit-length vectors. Unnormalized query vectors are
  # fine because they effectively provide a learnable temperature for the
  # attention softmax, but normalizing keys is needed so that similarity for
  # the purposes of attention correctly corresponds to hash locality.
  bq = bqk
  bk = self.make_unit_length(bqk)

  # Allow each chunk to attend within itself, and also one chunk back. Chunk
  # boundaries might occur in the middle of a sequence of items from the
  # same bucket, so this increases the chances of attending to relevant items.
  # TODO(kitaev): benchmark whether XLA pad operation is noticeably faster.
  def look_one_back(x):
    """Append (axis 1) the previous chunk, taken cyclically along axis 0."""
    if len(x.shape) == 2:
      x_extra = np.concatenate([x[-1:, :], x[:-1, :]], axis=0)
    else:
      x_extra = np.concatenate([x[-1:, :, :], x[:-1, :, :]], axis=0)
    return np.concatenate([x, x_extra], axis=1)

  bk = look_one_back(bk)
  bv = look_one_back(bv)
  bkv_t = look_one_back(bkv_t)
  bkv_buckets = look_one_back(bkv_buckets)

  # Dot-product attention.
  dots = np.matmul(bq, np.swapaxes(bk, -1, -2)) / np.sqrt(bq.shape[-1])

  # Causal masking
  # (keys at a later position than the query get a -1e9 penalty)
  mask = jax.lax.convert_element_type(
      jax.lax.lt(bq_t[:, :, None], bkv_t[:, None, :]),
      np.float32)
  dots = dots - 1e9 * mask

  # Mask out attention to self except when no other targets are available.
  # (same position => -1e5, a smaller penalty than the causal mask)
  self_mask = jax.lax.convert_element_type(
      jax.lax.eq(bq_t[:, :, None], bkv_t[:, None, :]),
      np.float32)
  dots = dots - 1e5 * self_mask

  # Mask out attention to other hash buckets.
  if not self._attend_across_buckets:
    bucket_mask = jax.lax.convert_element_type(
        jax.lax.ne(bq_buckets[:, :, None], bkv_buckets[:, None, :]),
        np.float32)
    dots = dots - 1e7 * bucket_mask

  # Don't double-count query-key pairs across multiple rounds of hashing.
  # There are two possible strategies here. (1) The default is to count how
  # many times a query-key pair is repeated, and to lower its log-prob
  # correspondingly at each repetition. (2) When hard_k is set, the code
  # instead masks all but the first occurrence of each query-key pair.
  # TODO(kitaev): is one strategy faster or more numerically stable?
  if not self._allow_duplicate_attention:
    # locs1: index of the chunk each (round, position) item landed in.
    # locs2: the following chunk, whose queries also see this item through
    # look_one_back.
    locs1 = undo_sort // bq_t.shape[-1]
    locs2 = (locs1 + 1) % (self.n_hashes * self.n_bins)
    if not self._attend_across_buckets:
      # Make locations bucket-specific, so a match also requires equal buckets.
      locs1 = buckets * (self.n_hashes * self.n_bins) + locs1
      locs2 = buckets * (self.n_hashes * self.n_bins) + locs2
    locs = np.moveaxis(np.concatenate([
        np.reshape(locs1, (self.n_hashes, seqlen)),
        np.reshape(locs2, (self.n_hashes, seqlen)),
    ], 0), 0, -1)  # produces shape (seqlen, 2 * self.n_hashes)
    slocs = np.take(locs, st, axis=0)
    b_locs = np.reshape(
        slocs, (self.n_hashes * self.n_bins, -1, 2 * self.n_hashes))
    # Queries always use the primary location (based on locs1).
    b_locs1 = b_locs[:, :, None, :self.n_hashes]
    if self._hard_k > 0:
      # Negate the query's locations for rounds at or after the current one,
      # so only matches in earlier rounds count as duplicates (keeps the
      # first occurrence).
      range_n_hashes = jax.lax.tie_in(b_locs, np.arange(self.n_hashes))
      nouse_locs = (range_n_hashes[:, None] > range_n_hashes[None, :])
      nouse_locs = 2 * nouse_locs - 1  # 1 = use, -1 = don't use
      nouse_locs = np.reshape(
          np.broadcast_to(nouse_locs[:, None, :],
                          (self.n_hashes, self.n_bins, self.n_hashes)),
          (self.n_hashes * self.n_bins, 1, 1, self.n_hashes))
      b_locs1 = b_locs1 * nouse_locs
    bq_locs = np.broadcast_to(
        b_locs1,
        b_locs.shape[:2] + (2, self.n_hashes))
    bq_locs = np.reshape(bq_locs, b_locs.shape)
    bkv_locs = look_one_back(b_locs)

    # Number of rounds in which this query/key pair shares a chunk.
    dup_counts = np.sum(
        jax.lax.convert_element_type(
            jax.lax.eq(bq_locs[:, :, None, :], bkv_locs[:, None, :, :]),
            np.float32),
        axis=-1)
    assert dup_counts.shape == dots.shape
    if self._hard_k > 0:
      dots = dots - 1e7 * jax.lax.stop_gradient(dup_counts)
    else:
      dots = dots - jax.lax.stop_gradient(np.log(dup_counts + 1e-9))

  # Each query only attends to the top k most relevant keys.
  if self._hard_k > 0:
    b_top_dots = np.sort(dots)[..., -self._hard_k:]  # Get the top k dots.
    b_top_dots = jax.lax.stop_gradient(b_top_dots)
    s_top_dots = np.reshape(b_top_dots, (-1, self._hard_k))
    top_dots = np.take(s_top_dots, undo_sort, axis=0)

    # Pool each position's top-k candidates from all rounds and take the k-th
    # largest as that position's threshold.
    merged_top_dots = np.moveaxis(
        np.reshape(top_dots, (self.n_hashes, seqlen, self._hard_k)), 0, -1)
    merged_top_dots = np.reshape(merged_top_dots, (seqlen, -1))

    dots_thresh = np.sort(merged_top_dots)[:, -self._hard_k]
    # It's possible to compute the partition function at this point, but right
    # now this codepath isn't set up for backprop, and there might also be
    # issues computing it this way if two dot-products are exactly equal.

    sdots_thresh = dots_thresh[st]
    bdots_thresh = np.reshape(sdots_thresh, (self.n_hashes * self.n_bins, -1))
    bdots_thresh = jax.lax.stop_gradient(bdots_thresh)

    top_k_mask = jax.lax.convert_element_type(
        dots < bdots_thresh[..., None], np.float32)
    dots = dots - 1e7 * jax.lax.stop_gradient(top_k_mask)

  # Softmax.
  dots_logsumexp = backend.logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  bo = np.matmul(dots, bv)
  so = np.reshape(bo, (-1, bo.shape[-1]))
  slogits = np.reshape(dots_logsumexp, (-1,))

  def unsort_for_output_impl(so, slogits):
    """Restore original (round, position) order for outputs and logits."""
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    # Adding the same 0/1 to every entry does not change the sort order.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)
    def vjpfun(o_logits_grads):
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)
    return (o, logits), vjpfun

  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  # NOTE(review): the plain impl is called here, not `unsort_for_output`, so
  # the custom VJP registered just above is never used -- confirm intent.
  o, logits = unsort_for_output_impl(so, slogits)

  if self.n_hashes == 1:
    out = o
  else:
    # Weight each round's output by the softmax over rounds of its logits.
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - backend.logsumexp(logits, axis=0, keepdims=True))
    out = np.sum(o * probs, axis=0)

  assert out.shape == v.shape
  return out