Example #1
0
def _fast_inference_update_state(inputs, state):
    """Appends new keys/values to the cached attention state for decoding.

    The layer state holds fixed-size (long) buffers of keys and values plus
    a mask and a write index. The new keys and values from `inputs` are
    written into the buffers at the index, the mask (initially all zeros) is
    set to 1 at the newly written positions, and the index is advanced by
    the number of new positions, so attention only looks at keys up to the
    index.

    Args:
      inputs: a triple (new_queries, new_keys, new_values)
      state: layer state with (keys, values, mask, index)

    Returns:
      Updated state.
    """
    # Fast inference: run step-by-step, caching keys/values seen so far.
    _, fresh_k, fresh_v = inputs
    cached_k, cached_v, attn_mask, write_pos = state
    n_new = fresh_k.shape[1]
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
    # path with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    cached_k = fastmath.dynamic_update_slice_in_dim(
        cached_k, fresh_k, write_pos, axis=1)
    cached_v = fastmath.dynamic_update_slice_in_dim(
        cached_v, fresh_v, write_pos, axis=1)
    # Mask is of shape [batch_size, 1 (for heads), length]; switch on the
    # freshly written positions.
    ones_patch = jnp.ones((attn_mask.shape[0], attn_mask.shape[1], n_new))
    attn_mask = fastmath.dynamic_update_slice_in_dim(
        attn_mask, ones_patch, write_pos, axis=2)
    return (cached_k, cached_v, attn_mask, write_pos + n_new)
Example #2
0
    def update_state(self, inputs):
        """Writes `inputs` into the circular cache and returns the window.

        The cache is twice `self._total_kv_pooling` long along axis 1; new
        tokens are written modulo that period, then a contiguous window of
        `self._total_kv_pooling` positions is sliced out and returned.
        """
        buf, pos = self.state
        period = 2 * self._total_kv_pooling

        # Primary write position inside the double-length circular cache.
        buf = fastmath.dynamic_update_slice_in_dim(
            buf, inputs, (pos + self._shift) % period, axis=1)

        if self._sliding:
            # Second (mirrored) write keeps a contiguous window available.
            buf = fastmath.dynamic_update_slice_in_dim(
                buf, inputs,
                (pos + period - 1) % period,
                axis=1)
            window_start = pos % self._total_kv_pooling
        else:
            window_start = (pos - pos % self._total_kv_pooling) % period

        window = fastmath.dynamic_slice(
            buf, [0, window_start, 0],
            [buf.shape[0], self._total_kv_pooling, buf.shape[2]])

        self.state = buf, pos + self._n_raw_tokens_generated
        return window
Example #3
0
 def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
     """Builds an all-ones mask, overlaying `mask` from position 1 onward."""
     base = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
     if base.shape[1] <= 1:
         # Single time step: nothing to shift in, keep the ones mask.
         return base, x_times_one_minus_f
     shifted = fastmath.dynamic_update_slice_in_dim(
         base != 0, mask != 0, 1, axis=1)
     return shifted, x_times_one_minus_f
Example #4
0
  def _fast_inference_update_state(self, inputs, state):
    """Caches one step's keys and values for fast autoregressive decoding.

    The layer state stores long, fixed-shape arrays of cached keys and
    values plus a raw-token index. The incoming key/value pair (always a
    single position) is written into the caches at a slot derived from the
    index, then the index is advanced by the number of raw tokens consumed
    per generation step.

    Args:
      inputs: a double (new_keys, new_values)
      state: layer state with (keys, values, index)
    """
    # Fast inference: run step-by-step, caching keys/values seen so far.
    fresh_k, fresh_v = inputs
    cached_k, cached_v, raw_idx = state

    # We cannot generate more than one token at a time, as that would
    # contradict the autoregressive property.
    assert fresh_k.shape[1] == 1

    slot = raw_idx // self._total_kv_pooling
    if self._chunk_len is not None:
      if self._chunk_offset != 0:
        # Shift slots past the offset back so chunks align with the offset.
        slot -= self._chunk_offset * (slot >= self._chunk_offset)
      slot %= self._chunk_len

    # Keys and values are of shape [batch_size, length, d_kv].
    cached_k = fastmath.dynamic_update_slice_in_dim(
        cached_k, fresh_k, slot, axis=1)
    cached_v = fastmath.dynamic_update_slice_in_dim(
        cached_v, fresh_v, slot, axis=1)

    self.state = cached_k, cached_v, raw_idx + self._n_raw_tokens_generated
Example #5
0
def _fast_inference_update_state(inputs, state):
  """Appends new keys/values to the cached state and builds a causal mask.

  The layer state stores long, fixed-shape arrays of cached keys and values
  plus an index marking where the next keys/values should be written. New
  keys and values are written at the index, the index is advanced by their
  length, and a causal mask covering the new queries is returned.

  Args:
    inputs: a triple (new_queries, new_keys, new_values)
    state: layer state with (keys, values, index)

  Returns:
    Updated state and mask to be used.
  """
  # Fast inference: run step-by-step, caching keys/values seen so far.
  _, fresh_k, fresh_v = inputs
  cached_k, cached_v, pos = state
  n_new = fresh_k.shape[1]
  # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
  # path with index_update when length == 1 is worth it.
  # Keys and values are of shape [batch_size, length, d_kv].
  cached_k = fastmath.dynamic_update_slice_in_dim(cached_k, fresh_k, pos,
                                                  axis=1)
  cached_v = fastmath.dynamic_update_slice_in_dim(cached_v, fresh_v, pos,
                                                  axis=1)
  total_len = cached_k.shape[1]

  # Mask is of shape [1, q_length, k_length]: query i sits at absolute
  # position pos + i and may attend to key j iff j <= pos + i.
  key_positions = jnp.reshape(jnp.arange(total_len), (1, 1, total_len))
  query_positions = jnp.reshape(jnp.arange(n_new) + pos, (1, n_new, 1))
  mask = key_positions <= query_positions

  return (cached_k, cached_v, pos + n_new), mask