def run_policy(
    policy_and_value_net_apply,
    observations,
    lengths,
    weights,
    state,
    rng,
    action_space,
):
  """Runs the policy network."""
  # TODO(pkozakowski): Pass the actual actions here, to enable autoregressive
  # action sampling.
  (B, T_plus_1) = observations.shape[:2]  # pylint: disable=invalid-name
  dummy_actions = onp.zeros(
      (B, T_plus_1 - 1) + action_space.shape, dtype=action_space.dtype)
  policy_input = (observations, dummy_actions)
  (rng, subrng) = trax_random.split(rng)
  (log_probs, value_preds) = policy_and_value_net_apply(
      policy_input, weights=weights, state=state, rng=subrng)
  # We need the log_probs of those actions that correspond to the last actual
  # time-step.
  index = lengths - 1  # Since we want to index using lengths.
  log_probs = log_probs[np.arange(B), index]
  value_preds = value_preds[np.arange(B), index]
  return (log_probs, value_preds, state, rng)
def run_policy(
    policy_and_value_net_apply,
    observations,
    lengths,
    weights,
    state,
    rng,
    action_space,
):
  """Runs the policy network; returns log probs and value preds for the last timestep."""
  log_probs, value_preds, state, rng = run_policy_all_timesteps(
      policy_and_value_net_apply,
      observations,
      weights,
      state,
      rng,
      action_space,
  )
  # We need the log_probs of those actions that correspond to the last actual
  # time-step.
  (B, unused_T_plus_1) = observations.shape[:2]  # pylint: disable=invalid-name
  index = lengths - 1  # Since we want to index using lengths.
  log_probs = log_probs[np.arange(B), index]
  value_preds = value_preds[np.arange(B), index]
  return (log_probs, value_preds, state, rng)
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
  """Actor loss."""
  # log_probab_actions_new's shape is (AB, 1, #C, #A), AB is actor batch.
  lp = jnp.squeeze(log_probab_actions_new, axis=1)
  AB, NC = actions.shape  # pylint: disable=invalid-name
  log_probs = lp[jnp.arange(AB)[:, None], jnp.arange(NC)[None, :], actions]
  # TODO(afrozm): Clarify this.
  # log_probs are shaped (AB, #C), however advantage_weights are (AB,)
  return -1.0 * jnp.mean(log_probs * advantage_weights[:, None]), state
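
# A standalone sketch (not part of the library) of the gather performed by
# actor_loss above: advanced indexing with index arrays of shapes (AB, 1),
# (1, NC) and (AB, NC) broadcasts to pick one log-prob per (example, control)
# pair. All names and numbers below are illustrative.
import numpy as np

AB, NC, NA = 2, 3, 4                              # actor batch, controls, actions
lp = np.log(np.full((AB, NC, NA), 1.0 / NA))      # uniform log-probs, (AB, NC, NA)
actions = np.array([[0, 1, 2], [3, 0, 1]])        # (AB, NC)
advantage_weights = np.array([1.0, -0.5])         # (AB,)

log_probs = lp[np.arange(AB)[:, None], np.arange(NC)[None, :], actions]
loss = -1.0 * np.mean(log_probs * advantage_weights[:, None])
print(log_probs.shape)                            # (2, 3)
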
def one_hot(x, n_categories, dtype=np.float32):  # pylint: disable=invalid-name
  """Makes a one-hot array (n+1 dims) from an int-categorical array (n dims)."""
  indices_less_than_n = np.arange(n_categories)
  if math.backend_name() == 'jax':
    # Work around a jax broadcasting issue.
    indices_less_than_n = jax.lax.tie_in(x, indices_less_than_n)
  return np.array(x[..., np.newaxis] == indices_less_than_n, dtype)
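
# A minimal sketch (plain NumPy, outside the backend machinery above) of the
# broadcast-compare trick one_hot relies on: compare x[..., None] against
# arange(n_categories) and cast the boolean result.
import numpy as np

x = np.array([0, 2, 1])
print(np.array(x[..., np.newaxis] == np.arange(3), np.float32))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
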
def forward_with_state(self, inputs, weights=layer_base.EMPTY_WEIGHTS,
                       state=layer_base.EMPTY_STATE, rng=None, **kwargs):
  depth = inputs.shape[-1]
  if self._mode == 'predict':
    emb = self._get_embeddings(t=state)
    emb = emb[:, np.newaxis, :]
    state = state + 1
  else:
    input_len = inputs.shape[-2]
    emb = self._get_embeddings(t=np.arange(input_len, dtype=np.int32))
    # Leave batch axis as 1 for broadcasting:
    emb = emb[np.newaxis, :, :]
    emb = np.broadcast_to(emb, inputs.shape[:-1] + (3,))
  # Replace the last num_features channels of input.
  inputs = np.concatenate([inputs[..., :-self.num_features], emb], -1)
  if inputs.shape[-1] > depth:
    logging.warning('dropping feature(s): %d down to %d', inputs.shape[-1],
                    depth)
    inputs = inputs[..., -depth:]
  assert inputs.shape[-1] == depth, inputs.shape
  return inputs, state
def forward_with_state(self, x, weights, state, rng):
  batch_size, length = x.shape[0], x.shape[1]
  max_pos = min(self._bases)**self._n_digits
  rng1, rng2, rng3 = math.random.split(rng, 3)
  assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
  positions = jnp.arange(0, length)[None, :]
  if self._mode == 'train':
    # In 1% of training cases still start from 0 to be exactly as in eval.
    start_from_nonzero = jax.random.randint(
        rng1, (batch_size,), 0, self._start_from_zero_one_in)
    start_from_nonzero = jnp.minimum(1, start_from_nonzero)
    random_start = jax.random.randint(rng2, (batch_size,), 0, max_pos - length)
    random_start *= start_from_nonzero
    positions += random_start[:, None]
  res = []
  for bn, base in enumerate(self._bases):
    pos_embeddings = []
    cur_positions = positions
    for i in range(self._n_digits):
      cur_indices = jnp.mod(cur_positions, base)
      cur_positions = cur_positions // base
      s = weights[bn][i]
      pos_embeddings.append(cur_indices.astype(jnp.float32)[:, :, None] * s)
    embeddings = jnp.concatenate(pos_embeddings, axis=-1)
    if self._mode == 'train':
      base_dropout = jax.random.randint(
          rng3, (batch_size,), 0, self._base_dropout_one_in)
      base_dropout = jnp.minimum(1, base_dropout).astype(jnp.float32)
      embeddings *= base_dropout[:, None, None]
    res.append(embeddings)
  res = sum(res) + jnp.zeros_like(x)
  return jnp.concatenate([x, res], axis=-1), state
def one_hot(x, size, dtype=np.float32):  # pylint: disable=invalid-name
  """Make a n+1 dim one-hot array from n dim int-categorical array."""
  arange_size = np.arange(size)
  if math.backend_name() == 'jax':
    # Work around a jax broadcasting issue.
    arange_size = jax.lax.tie_in(x, arange_size)
  return np.array(x[..., np.newaxis] == arange_size, dtype)
def forward_unbatched(self, x, *, weights, state, update_state):
  del update_state
  if self.share_qk:
    w_q, w_v, w_o = weights
  else:
    w_q, w_k, w_v, w_o = weights
  q = np.matmul(x, w_q)
  k = None
  if not self.share_qk:
    k = np.matmul(x, w_k)
  v = np.matmul(x, w_v)

  mask_fn = functools.partial(
      mask_self_attention, causal=self.causal, exclude_self=self.share_qk)
  q_info = kv_info = jax.lax.tie_in(x, np.arange(q.shape[-2]))

  o, _ = attend(
      q, k, v,
      q_chunk_len=self.chunk_len,
      kv_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info, kv_info=kv_info,
      dropout=self.attention_dropout, rng=None,  # TODO(kitaev): support RNG
  )

  out = np.matmul(o, w_o)
  return out, state
def log_prob(self, inputs, point):
  # Flatten the prefix dimensions for easy indexing.
  flat_point = np.reshape(point, -1)
  flat_inputs = np.reshape(inputs, (point.size, -1))
  flat_log_probs = flat_inputs[np.arange(point.size), flat_point.astype(int)]
  return np.reshape(flat_log_probs, point.shape)
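
# Illustrative sketch (not the library's distribution API) of the
# flatten-then-gather pattern used by log_prob above: reshape to
# (point.size, n_categories), pick one entry per row, reshape back.
import numpy as np

inputs = np.log(np.full((2, 3, 4), 0.25))   # log-probs over 4 categories
point = np.array([[0, 1, 2], [3, 0, 1]])    # categorical samples, shape (2, 3)
flat_inputs = np.reshape(inputs, (point.size, -1))
flat_log_probs = flat_inputs[np.arange(point.size), np.reshape(point, -1)]
print(np.reshape(flat_log_probs, point.shape).shape)   # (2, 3)
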
def forward(self, inputs, weights):
  state = self.state
  depth = inputs.shape[-1]
  if self._mode == 'predict':
    emb = self._get_embeddings(t=state)
    emb = emb[:, jnp.newaxis, :]
    state = state + 1
  else:
    input_len = inputs.shape[-2]
    emb = self._get_embeddings(t=jnp.arange(input_len, dtype=jnp.int32))
    # Leave batch axis as 1 for broadcasting:
    emb = emb[jnp.newaxis, :, :]
    emb = jnp.broadcast_to(emb, inputs.shape[:-1] + (3,))
  # Replace the last num_features channels of input.
  inputs = jnp.concatenate([inputs[..., :-self.num_features], emb], -1)
  if inputs.shape[-1] > depth:
    logging.warning('dropping feature(s): %d down to %d', inputs.shape[-1],
                    depth)
    inputs = inputs[..., -depth:]
  assert inputs.shape[-1] == depth, inputs.shape
  self.state = state
  return inputs
def actor_loss(actions, advantage_weights, log_probab_actions_new, state=None):
  """Actor loss."""
  lp = np.squeeze(log_probab_actions_new)
  b = len(lp)
  log_probs = np.squeeze(lp[np.arange(b)[np.newaxis, :], actions])
  return -1.0 * np.mean(log_probs * advantage_weights), state
def F(vec_e, vec_d, mask_e, mask_d):  # pylint: disable=invalid-name
  # pylint: disable=invalid-name
  L1 = mask_e.shape[1]
  L2 = mask_d.shape[1]
  # pylint: enable=invalid-name
  # [-(L1+L2), -L2) but with padding 0-ed out - (B, L1).
  mask_e_key = jnp.arange(-(L1 + L2), -L2) * mask_e
  # [-L2, 0) but with padding 0-ed out - (B, L2).
  mask_d_key = jnp.arange(-L2, 0) * mask_d

  # Shape (B, L1+L2, H)
  enc_dec_concat = jnp.concatenate([vec_e, vec_d], axis=1)
  # Shape (B, L1+L2)
  mask_concat = jnp.concatenate([mask_e_key, mask_d_key], axis=1)
  # Make `mask_concat` the same shape as `enc_dec_concat`
  mask_concat = (
      mask_concat[..., jnp.newaxis] +
      jnp.zeros_like(enc_dec_concat, dtype=jnp.int32))
  # Sort on `mask_concat` so padding with key=0 goes to the right end, axis=1.
  _, enc_dec_pad = math.sort_key_val(mask_concat, enc_dec_concat, 1)

  return enc_dec_pad
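
# A toy illustration (assumed shapes, not library code) of the pack-by-sort
# trick in F above: valid positions get strictly negative keys, padding gets
# key 0, so a stable ascending sort pushes padding to the right end.
import jax.numpy as jnp
from jax import lax

mask_e = jnp.array([[1, 1, 0]])              # encoder mask, last slot padded
mask_d = jnp.array([[1, 1]])                 # decoder mask
key_e = jnp.arange(-5, -2) * mask_e          # [-5, -4, -3] * [1, 1, 0] -> [-5, -4, 0]
key_d = jnp.arange(-2, 0) * mask_d           # [-2, -1]
keys = jnp.concatenate([key_e, key_d], axis=1)
vals = jnp.array([[10, 11, 99, 20, 21]])     # 99 marks the padded slot
_, packed = lax.sort_key_val(keys, vals, dimension=1)
print(packed)                                # [[10 11 20 21 99]]
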
def hash_vectors(self, vecs, rng):
  # See https://arxiv.org/pdf/1509.02897.pdf
  # We sample a different random rotation for each round of hashing to
  # decrease the probability of hash misses.
  if isinstance(self.n_buckets, int):
    assert self.n_buckets % 2 == 0
    rot_size = self.n_buckets
    n_buckets = self.n_buckets
  else:
    # Factorize the hash if self.n_buckets is a list or tuple
    rot_size, n_buckets = 0, 1
    for factor in self.n_buckets:
      assert factor % 2 == 0
      rot_size += factor
      n_buckets *= factor

  rotations_shape = (vecs.shape[-1], self.n_hashes, rot_size // 2)

  rng = jax.lax.stop_gradient(jax.lax.tie_in(vecs, rng))
  random_rotations = jax.random.normal(rng, rotations_shape).astype('float32')
  rotated_vecs = np.einsum('tf,fhb->htb', vecs, random_rotations)

  if isinstance(self.n_buckets, int) or len(self.n_buckets) == 1:
    rotated_vecs = np.concatenate([rotated_vecs, -rotated_vecs], axis=-1)
    buckets = np.argmax(rotated_vecs, axis=-1)
  else:
    # Get the buckets for them and combine.
    buckets, cur_sum, cur_product = None, 0, 1
    for factor in self.n_buckets:
      rv = rotated_vecs[..., cur_sum:cur_sum + (factor // 2)]
      cur_sum += factor // 2
      rv = np.concatenate([rv, -rv], axis=-1)
      if buckets is None:
        buckets = np.argmax(rv, axis=-1)
      else:
        buckets += cur_product * np.argmax(rv, axis=-1)
      cur_product *= factor

  # buckets is now (self.n_hashes, seqlen). Next we add offsets so that
  # bucket numbers from different hashing rounds don't overlap.
  offsets = jax.lax.tie_in(buckets, np.arange(self.n_hashes))
  offsets = np.reshape(offsets * n_buckets, (-1, 1))
  buckets = np.reshape(buckets + offsets, (-1,))

  return buckets
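
# A self-contained toy version (illustrative only) of the random-rotation LSH
# above: project vectors, concatenate with negations, take argmax per hash
# round, then offset bucket ids so rounds don't collide.
import jax
import jax.numpy as jnp

seqlen, d, n_hashes, n_buckets = 5, 4, 2, 8
vecs = jax.random.normal(jax.random.PRNGKey(0), (seqlen, d))
rots = jax.random.normal(jax.random.PRNGKey(1), (d, n_hashes, n_buckets // 2))
rotated = jnp.einsum('tf,fhb->htb', vecs, rots)
rotated = jnp.concatenate([rotated, -rotated], axis=-1)
buckets = jnp.argmax(rotated, axis=-1)                    # (n_hashes, seqlen)
offsets = jnp.arange(n_hashes)[:, None] * n_buckets
print(jnp.reshape(buckets + offsets, (-1,)))              # disjoint ids per round
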
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  assert math.backend_name() == 'jax', (
      'JAX backend is required to use the predict mode.')
  for x in inputs:
    assert x.shape[1] == 1, (
        'In predict mode the input sequence must be of length 1.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = np.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
def forward_unbatched(self, x, mask=None, *, weights, state, update_state):
  del update_state
  if self.share_qk:
    w_q, w_v, w_o = weights
  else:
    w_q, w_k, w_v, w_o = weights
  q = np.matmul(x, w_q)
  k = None
  if not self.share_qk:
    k = np.matmul(x, w_k)
  v = np.matmul(x, w_v)

  mask_fn = functools.partial(
      mask_self_attention, causal=self.causal, exclude_self=self.share_qk,
      masked=self.masked)
  q_info = kv_info = jax.lax.tie_in(x, np.arange(q.shape[-2]))

  assert (mask is not None) == self.masked
  if self.masked:
    # mask is a boolean array (True means "is valid token")
    ones_like_mask = jax.lax.tie_in(x, np.ones_like(mask, dtype=np.int32))
    kv_info = kv_info * np.where(mask, ones_like_mask, -ones_like_mask)

  o, _ = attend(
      q, k, v,
      q_chunk_len=self.chunk_len,
      kv_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info, kv_info=kv_info,
      dropout=self.attention_dropout, rng=None,  # TODO(kitaev): support RNG
  )

  out = np.matmul(o, w_o)
  return out, state
def forward_with_state(self, x, weights=layer_base.EMPTY_WEIGHTS,
                       state=layer_base.EMPTY_STATE, rng=None, **kwargs):
  length = np.shape(x)[1]
  max_pos = self._base**self._n_digits
  assert length < max_pos, 'length (%d) >= max_pos (%d)' % (length, max_pos)
  positions = np.arange(0, length)
  if self._mode == 'train':
    positions += jax.random.randint(rng, (), 0, max_pos - length)
  pos_embeddings = []
  cur_positions = positions
  for i in range(self._n_digits):
    cur_indices = np.mod(cur_positions, self._base)
    cur_positions //= self._base
    pos_embeddings.append(np.take(weights[i], cur_indices, axis=0))
  embeddings = np.concatenate(pos_embeddings, axis=-1)
  return (x + embeddings[None, :, :], state)
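
# Small sketch (illustrative base and digit count) of the digit decomposition
# above: each position is written in base `base`, and each digit would index
# its own embedding table.
import numpy as np

base, n_digits = 8, 3
positions = np.arange(0, 5)
digits = []
cur = positions
for _ in range(n_digits):
  digits.append(np.mod(cur, base))
  cur = cur // base
print(np.stack(digits, axis=-1))
# [[0 0 0]
#  [1 0 0]
#  [2 0 0]
#  [3 0 0]
#  [4 0 0]]
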
def _fast_inference_update_state(inputs, state):
  """Updates state of a causal attention layer for fast inference."""
  if math.backend_name() != 'jax':
    raise ValueError(f'JAX backend is required in predict mode, but found '
                     f'backend ({math.backend_name()}).')
  for x in inputs:
    if x.shape[1] != 1:
      raise ValueError(f'In predict mode, input sequence must have length 1, '
                       f'instead has length {x.shape[1]}.')
  # Fast inference: run with only 1 query in each step, storing the sequence
  # of keys and values calculated so far in state.
  (_, new_k, new_v) = inputs
  (ks, vs, mask, seq_indices) = state
  batch_indices = jnp.arange(ks.shape[0])
  ks = jax.ops.index_update(
      ks, jax.ops.index[batch_indices, seq_indices, :], new_k[:, 0, :])
  vs = jax.ops.index_update(
      vs, jax.ops.index[batch_indices, seq_indices, :], new_v[:, 0, :])
  mask = jax.ops.index_update(
      mask, jax.ops.index[batch_indices, :, seq_indices], 1)
  return (ks, vs, mask, seq_indices + 1)
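
# Sketch (toy shapes) of the per-batch cache write performed above. The code
# above uses the jax.ops.index_update API of its era; on newer JAX versions
# the equivalent functional update is the .at[...].set(...) form shown here.
import jax.numpy as jnp

ks = jnp.zeros((2, 4, 3))                 # (batch, max_len, d)
seq_indices = jnp.array([1, 3])           # next write position per example
new_k = jnp.ones((2, 1, 3))               # one new key per example
batch_indices = jnp.arange(ks.shape[0])
ks = ks.at[batch_indices, seq_indices, :].set(new_k[:, 0, :])
print(ks[0, 1], ks[1, 3])                 # both are [1. 1. 1.]
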
def test_batch_norm(self):
  input_shape = (2, 3, 4)
  input_dtype = np.float32
  input_signature = ShapeDtype(input_shape, input_dtype)
  eps = 1e-5
  inp1 = np.reshape(np.arange(np.prod(input_shape), dtype=input_dtype),
                    input_shape)
  m1 = 11.5  # Mean of this input.
  v1 = 47.9167  # Variance of this input.
  layer = normalization.BatchNorm(axis=(0, 1, 2))
  _, _ = layer.init(input_signature)
  state = layer.state
  onp.testing.assert_allclose(state[0], 0)
  onp.testing.assert_allclose(state[1], 1)
  self.assertEqual(state[2], 0)
  out = layer(inp1)
  state = layer.state
  onp.testing.assert_allclose(state[0], m1 * 0.001)
  onp.testing.assert_allclose(state[1], 0.999 + v1 * 0.001, rtol=1e-6)
  self.assertEqual(state[2], 1)
  onp.testing.assert_allclose(out, (inp1 - m1) / np.sqrt(v1 + eps),
                              rtol=1e-6)
def threefry_2x32_prange(key, lo: int = 0, hi: int = 2):
  """Splits a key into a stream of random keys.

  This uses the little-endian counter mode.

  Args:
    key: uint32[2] the key to split
    lo: the range to start extracting from
    hi: the range to stop extracting from

  Returns:
    keys: uint32[hi - lo, 2] the split keys
  """
  if not (key.shape == (2,) and key.dtype == np.uint32):
    raise ValueError('key must be uint32[2]')
  if not hi < 2**32:
    # You shouldn't really be using more than half the key size anyways.
    raise NotImplementedError('only 32-bit sizes are supported')
  # Create a 64-bit counter:
  i_lo = np.arange(lo, hi, dtype=np.uint32)
  i_hi = np.zeros_like(i_lo)
  i = np.stack([i_lo, i_hi], axis=-1)
  return threefry_2x32_prf(key, i)
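
# A short sketch (plain NumPy, no Threefry call) of the little-endian counter
# layout built above: each row is one 64-bit counter split into two uint32
# words, with the high word fixed at zero for the supported 32-bit range.
import numpy as np

lo, hi = 0, 4
i_lo = np.arange(lo, hi, dtype=np.uint32)
i_hi = np.zeros_like(i_lo)
print(np.stack([i_lo, i_hi], axis=-1))
# [[0 0]
#  [1 0]
#  [2 0]
#  [3 0]]
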
def top_k(x, k):
  """Select the top k slices from the last dimension."""
  bcast_idxs = jnp.broadcast_to(np.arange(x.shape[-1]), x.shape)
  sorted_vals, sorted_idxs = lax.sort_key_val(x, bcast_idxs)
  # TODO(levskaya): use lax.slice here instead to benefit from XLA optimization
  return sorted_vals[..., -k:], sorted_idxs[..., -k:]
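
# Usage sketch (values are illustrative): top_k sorts the values together with
# broadcast indices and slices the last k entries of each row.
import jax.numpy as jnp
from jax import lax

x = jnp.array([[3., 1., 4., 1., 5.],
               [9., 2., 6., 5., 3.]])
idxs = jnp.broadcast_to(jnp.arange(x.shape[-1]), x.shape)
vals, order = lax.sort_key_val(x, idxs)
print(vals[..., -2:])    # [[4. 5.] [6. 9.]]
print(order[..., -2:])   # [[2 4] [2 0]]
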
def forward_unbatched(self, x, *, weights, state, update_state):
  w_q, w_v, w_o = weights

  q = np.matmul(x, w_q)
  v = np.matmul(x, w_v)

  if update_state:
    _, old_rng = state
    rng = jax.random.fold_in(old_rng, 0)
    hash_rng = jax.random.fold_in(rng, 1)
    buckets = self.hash_vectors(q, hash_rng)
    state = (buckets, rng)
  else:
    buckets, rng = state
    rng = jax.random.fold_in(rng, 2)

  seqlen = x.shape[0]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sq = np.take(q, st, axis=0)
  sv = np.take(v, st, axis=0)

  mask_fn = functools.partial(
      mask_self_attention, causal=self.causal, exclude_self=True)
  q_info = st
  so, slogits = attend(
      sq, k=None, v=sv,
      q_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info,
      dropout=self.attention_dropout, rng=rng,
  )

  def unsort_for_output_impl(so, slogits):
    o = np.take(so, undo_sort, axis=0)
    # Sorting is considerably faster than gather, but first we need to get the
    # XLA compiler to abandon the idea of fusing this sort with the input sort
    # (which introduces a computation cycle and leads to a crash).
    # TODO(kitaev): remove "sticker_" variable if XLA is fixed.
    sticker_ = sticker + jax.lax.convert_element_type(
        slogits[0] > 0, sticker.dtype)
    _, logits = jax.lax.sort_key_val(sticker_, slogits, dimension=-1)
    return o, logits

  def unsort_for_output_vjp(so, slogits):
    """Custom gradient for unsort_for_output."""
    so = jax.lax.stop_gradient(so)
    slogits = jax.lax.stop_gradient(slogits)
    o, logits = unsort_for_output_impl(so, slogits)
    def vjpfun(o_logits_grads):
      so_grad = np.take(o_logits_grads[0], sticker, axis=0)
      # TODO(kitaev): this exists to match the forward pass, but I'm not sure
      # if it's actually required.
      buckets_and_t_ = buckets_and_t + jax.lax.convert_element_type(
          o_logits_grads[1][0] > 0, buckets_and_t.dtype)
      _, slogits_grad = jax.lax.sort_key_val(
          buckets_and_t_, o_logits_grads[1], dimension=-1)
      return (so_grad, slogits_grad)
    return (o, logits), vjpfun

  unsort_for_output = jax.custom_transforms(unsort_for_output_impl)
  jax.defvjp_all(unsort_for_output, unsort_for_output_vjp)
  o, logits = unsort_for_output_impl(so, slogits)

  if self.n_hashes > 1:
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = np.sum(o * probs, axis=0)

  assert o.shape == (seqlen, w_v.shape[-1])
  out = np.matmul(o, w_o)
  return out, state
def attend(q, k=None, v=None,
           q_chunk_len=None, kv_chunk_len=None,
           n_chunks_before=0, n_chunks_after=0,
           mask_fn=None, q_info=None, kv_info=None,
           dropout=0.0, rng=None,
           ):
  """Dot-product attention, with optional chunking and/or masking.

  Args:
    q: Query vectors, shape [q_len, d_qk]
    k: Key vectors, shape [kv_len, d_qk]; or None
    v: Value vectors, shape [kv_len, d_v]
    q_chunk_len: Set to non-zero to enable chunking for query vectors
    kv_chunk_len: Set to non-zero to enable chunking for key/value vectors
    n_chunks_before: Number of adjacent previous chunks to attend to
    n_chunks_after: Number of adjacent subsequent chunks to attend to
    mask_fn: TODO(kitaev) doc
    q_info: Query-associated metadata for masking
    kv_info: Key-associated metadata for masking
    dropout: Dropout rate
    rng: RNG for dropout

  Returns:
    A tuple (output, dots_logsumexp). The output has shape [q_len, d_v], and
    dots_logsumexp has shape [q_len]. The logsumexp of the attention
    probabilities is useful for combining multiple rounds of attention (as in
    LSH attention).
  """
  assert v is not None
  share_qk = (k is None)

  if q_info is None:
    q_info = np.arange(q.shape[-2])

  if kv_info is None and not share_qk:
    kv_info = np.arange(v.shape[-2])

  # Split q/k/v into chunks along the time axis, if desired.
  if q_chunk_len is not None:
    q = np.reshape(q, (-1, q_chunk_len, q.shape[-1]))
    q_info = np.reshape(q_info, (-1, q_chunk_len))

  if share_qk:
    assert kv_chunk_len is None or kv_chunk_len == q_chunk_len
    k = q
    kv_chunk_len = q_chunk_len
    kv_info = q_info
  elif kv_chunk_len is not None:
    k = np.reshape(k, (-1, kv_chunk_len, k.shape[-1]))
    kv_info = np.reshape(kv_info, (-1, kv_chunk_len))

  if kv_chunk_len is not None:
    v = np.reshape(v, (-1, kv_chunk_len, v.shape[-1]))

  if share_qk:
    k = length_normalized(k)
  k = k / np.sqrt(k.shape[-1])

  # Optionally include adjacent chunks.
  if q_chunk_len is not None or kv_chunk_len is not None:
    assert q_chunk_len is not None and kv_chunk_len is not None
  else:
    assert n_chunks_before == 0 and n_chunks_after == 0

  k = look_adjacent(k, n_chunks_before, n_chunks_after)
  v = look_adjacent(v, n_chunks_before, n_chunks_after)
  kv_info = look_adjacent(kv_info, n_chunks_before, n_chunks_after)

  # Dot-product attention.
  dots = np.matmul(q, np.swapaxes(k, -1, -2))

  # Masking
  if mask_fn is not None:
    dots = mask_fn(dots, q_info[..., :, None], kv_info[..., None, :])

  # Softmax.
  dots_logsumexp = logsumexp(dots, axis=-1, keepdims=True)
  dots = np.exp(dots - dots_logsumexp)

  if dropout > 0.0:
    assert rng is not None
    # Dropout is broadcast across the bin dimension
    dropout_shape = (dots.shape[-2], dots.shape[-1])
    # TODO(kitaev): verify that tie-in is safe to remove (in light of jax fix)
    keep_prob = jax.lax.tie_in(dots, 1.0 - dropout)
    keep = jax.random.bernoulli(rng, keep_prob, dropout_shape)
    multiplier = keep.astype(dots.dtype) / jax.lax.tie_in(keep, keep_prob)
    dots = dots * multiplier

  # The softmax normalizer (dots_logsumexp) is used by multi-round LSH attn.
  out = np.matmul(dots, v)
  out = np.reshape(out, (-1, out.shape[-1]))
  dots_logsumexp = np.reshape(dots_logsumexp, (-1,))
  return out, dots_logsumexp
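
# Minimal unchunked sketch (assumed shapes, no masking or dropout) of the core
# computation in attend above: scaled dot-products, a logsumexp-stabilized
# softmax, and the per-query logsumexp that later rounds can combine.
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

q = jax.random.normal(jax.random.PRNGKey(0), (6, 8))
v = jax.random.normal(jax.random.PRNGKey(1), (6, 8))
k = jax.random.normal(jax.random.PRNGKey(2), (6, 8)) / jnp.sqrt(8.0)  # scale by 1/sqrt(d_qk)
dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2))
lse = logsumexp(dots, axis=-1, keepdims=True)
out = jnp.matmul(jnp.exp(dots - lse), v)
print(out.shape, jnp.reshape(lse, (-1,)).shape)   # (6, 8) (6,)
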
def forward_unbatched(self, x, *, weights, state, update_state):
  w_q, w_v, w_o = weights

  q = np.matmul(x, w_q)
  v = np.matmul(x, w_v)

  if update_state:
    _, old_rng = state
    rng = jax.random.fold_in(old_rng, 0)
    hash_rng = jax.random.fold_in(rng, 1)
    buckets = self.hash_vectors(q, hash_rng)
    state = (buckets, rng)
  else:
    buckets, rng = state
    rng = jax.random.fold_in(rng, 2)

  seqlen = x.shape[0]
  assert int(buckets.shape[0]) == self.n_hashes * seqlen

  ticker = jax.lax.tie_in(x, np.arange(self.n_hashes * seqlen))
  buckets_and_t = seqlen * buckets + (ticker % seqlen)
  buckets_and_t = jax.lax.stop_gradient(buckets_and_t)

  # Hash-based sort ("s" at the start of variable names means "sorted")
  sbuckets_and_t, sticker = jax.lax.sort_key_val(
      buckets_and_t, ticker, dimension=-1)
  _, undo_sort = jax.lax.sort_key_val(sticker, ticker, dimension=-1)
  sbuckets_and_t = jax.lax.stop_gradient(sbuckets_and_t)
  sticker = jax.lax.stop_gradient(sticker)
  undo_sort = jax.lax.stop_gradient(undo_sort)

  st = (sticker % seqlen)
  sq = np.take(q, st, axis=0)
  sv = np.take(v, st, axis=0)

  mask_fn = functools.partial(
      mask_self_attention, causal=self.causal, exclude_self=True)
  q_info = st
  so, slogits = attend(
      sq, k=None, v=sv,
      q_chunk_len=self.chunk_len,
      n_chunks_before=self.n_chunks_before,
      n_chunks_after=self.n_chunks_after,
      mask_fn=mask_fn, q_info=q_info,
      dropout=self.attention_dropout, rng=rng,
  )

  # np.take(so, undo_sort, axis=0); np.take(slogits, undo_sort, axis=0) would
  # also work, but these helpers include performance optimizations for TPU.
  o = permute_via_gather(so, undo_sort, sticker, axis=0)
  logits = permute_via_sort(slogits, sticker, buckets_and_t, axis=-1)

  if self.n_hashes > 1:
    o = np.reshape(o, (self.n_hashes, seqlen, o.shape[-1]))
    logits = np.reshape(logits, (self.n_hashes, seqlen, 1))
    probs = np.exp(logits - logsumexp(logits, axis=0, keepdims=True))
    o = np.sum(o * probs, axis=0)

  assert o.shape == (seqlen, w_v.shape[-1])
  out = np.matmul(o, w_o)
  return out, state