def _fast_inference_update_state(inputs, state):
    """Updates the cached state of a causal attention layer for fast inference.

    The layer state holds fixed-size buffers for keys and values, a mask, and
    an index marking where the next tokens should be written.  New keys and
    values from `inputs` are copied into the buffers at the index position,
    the mask is switched on over the freshly written positions, and the index
    is advanced by the number of new tokens.

    Args:
      inputs: a triple (new_queries, new_keys, new_values); queries are unused.
      state: layer state as (keys, values, mask, index).

    Returns:
      Updated state tuple (keys, values, mask, new_index).
    """
    _, new_keys, new_values = inputs
    n_new = new_keys.shape[1]
    keys, values, mask, index = state
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
    # path with index_update when length == 1 is worth it.
    # Key/value buffers are of shape [batch_size, length, d_kv]; splice the
    # new slice in place at `index` along the length axis.
    keys = fastmath.dynamic_update_slice_in_dim(keys, new_keys, index, axis=1)
    values = fastmath.dynamic_update_slice_in_dim(
        values, new_values, index, axis=1)
    # Mask is of shape [batch_size, 1 (for heads), length]; it starts all-0
    # and the newly written positions are flipped to 1 here.
    ones = jnp.ones((mask.shape[0], mask.shape[1], n_new))
    mask = fastmath.dynamic_update_slice_in_dim(mask, ones, index, axis=2)
    return (keys, values, mask, index + n_new)
def update_state(self, inputs):
    """Writes `inputs` into the circular cache and returns the current window.

    The cache is a buffer of 2 * total_kv_pooling slots along axis 1, indexed
    modulo its length.  New tokens are written at a shifted slot; in sliding
    mode each token is additionally written one slot before the buffer end
    (presumably so the window read below stays contiguous — see caller).
    A window of total_kv_pooling positions is then sliced out and the raw
    token counter is advanced.
    """
    cache, index = self.state
    period = 2 * self._total_kv_pooling
    # Place the new tokens at their (shifted) slot in the circular buffer.
    cache = fastmath.dynamic_update_slice_in_dim(
        cache, inputs, (index + self._shift) % period, axis=1)
    if self._sliding:
        # Duplicate the write one slot before the buffer end.
        cache = fastmath.dynamic_update_slice_in_dim(
            cache, inputs, (index + period - 1) % period, axis=1)
        window_start = index % self._total_kv_pooling
    else:
        # Snap down to the start of the current total_kv_pooling-sized
        # segment, wrapped into the buffer.
        window_start = (index - index % self._total_kv_pooling) % period
    output = fastmath.dynamic_slice(
        cache, [0, window_start, 0],
        [cache.shape[0], self._total_kv_pooling, cache.shape[2]])
    self.state = cache, index + self._n_raw_tokens_generated
    return output
def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
    """Builds a mask aligned with `x_times_one_minus_f` and passes it through.

    Starts from an all-ones array of shape x.shape[:2]; when the second
    dimension is longer than 1, positions from index 1 onward are overwritten
    with the boolean version of `mask`.

    NOTE(review): the two branches return different dtypes (bool after the
    splice, float32 otherwise) — confirm callers tolerate this.
    """
    base = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
    if base.shape[1] <= 1:
        return base, x_times_one_minus_f
    # Cast both operands to booleans and splice `mask` in starting at index 1.
    new_mask = fastmath.dynamic_update_slice_in_dim(
        base != 0, mask != 0, 1, axis=1)
    return new_mask, x_times_one_minus_f
def _fast_inference_update_state(self, inputs, state):
    """Updates state of a causal attention layer for fast inference.

    The layer state stores arrays with cached values of keys and values, as
    well as a raw-token index.  To make shapes static, the cached keys and
    values are long, and the index indicates where the new key and value from
    `inputs` must be written.  The write position is the raw index divided by
    `self._total_kv_pooling`, optionally shifted and wrapped into the chunk
    when chunked attention is configured.  Note: no mask is created here —
    the state update is the only effect.

    Args:
      inputs: a pair (new_keys, new_values).
      state: layer state as (keys, values, index).

    Raises:
      ValueError: if `inputs` carries more than one token; autoregressive
        decoding generates exactly one token per step.
    """
    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    new_k, new_v = inputs
    length = new_k.shape[1]
    (ks, vs, idx) = state
    # A longer slice would contradict the autoregressive one-token-per-step
    # contract.  (This was a bare `assert`, which is silently stripped under
    # `python -O`; raise explicitly instead.)
    if length != 1:
        raise ValueError(
            f'Fast inference expects exactly one new token, got {length}.')
    new_index = idx // self._total_kv_pooling
    if self._chunk_len is not None:
        # Shift past the chunk offset (only for positions beyond it), then
        # wrap the write position into the current chunk.
        if self._chunk_offset != 0:
            new_index -= self._chunk_offset * (new_index >= self._chunk_offset)
        new_index = new_index % self._chunk_len
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, new_index, axis=1)
    vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, new_index, axis=1)
    self.state = ks, vs, idx + self._n_raw_tokens_generated
def _fast_inference_update_state(inputs, state):
    """Appends new keys/values to the cache and builds the causal mask.

    The layer state stores long (static-shape) arrays of cached keys and
    values plus an index marking where the next tokens go.  The new keys and
    values are written at the index, the index advances by their length, and
    a causal mask is produced for the new queries.

    Args:
      inputs: a triple (new_queries, new_keys, new_values); queries are unused.
      state: layer state as (keys, values, index).

    Returns:
      A pair ((keys, values, new_index), mask).
    """
    # Fast inference: run step-by-step, storing the sequence of keys and
    # values calculated so far in state.
    _, new_keys, new_values = inputs
    n_new = new_keys.shape[1]
    keys, values, index = state
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
    # path with index_update when length == 1 is worth it.
    # Cached keys/values are of shape [batch_size, length, d_kv].
    keys = fastmath.dynamic_update_slice_in_dim(keys, new_keys, index, axis=1)
    values = fastmath.dynamic_update_slice_in_dim(
        values, new_values, index, axis=1)
    total_len = keys.shape[1]
    # Mask is of shape [1, n_new, total_len]: the query written at absolute
    # position index + i may attend to key positions <= index + i.
    key_positions = jnp.arange(total_len).reshape((1, 1, total_len))
    query_positions = (jnp.arange(n_new) + index).reshape((1, n_new, 1))
    mask = key_positions <= query_positions
    return (keys, values, index + n_new), mask