Ejemplo n.º 1
0
    def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Creates a funnel mask.

    This function based on keys/queries lengths creates a triangle mask
    that prevents tokens from attending to positions following it.

    If funnel_factor is not equal to 1 due to funnel upsampling or
    downsampling it adjusts created mask for funnel attention
    by repeating each element funnel_factor times.

    This is because after funnel layer one token attends to funnel_factor
    different tokens in downsampling. During upsampling on the other hand
    funnel_factor tokens are attending to single token before upsampling.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: upsampling if set to True.

    Returns:
      Funnel mask.
    """

        if self._mode == 'predict':
            # We cannot generate more than one token because it contradicts
            # all autoregressive properties
            assert queries_len == 1
            mask = jnp.arange(
                self._max_len) <= (self.state // self._total_kv_pooling)
            mask = jnp.reshape(mask, (1, 1, 1, self._max_len))
            mask = jnp.repeat(mask, batch_size, axis=0)
            self.state += self._n_raw_tokens_generated
            return mask

        if funnel_factor != 1:
            if not is_upsampling:
                mask = jnp.tril(
                    jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-1)
            else:
                mask = jnp.tril(jnp.ones((keys_len, keys_len),
                                         dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-2)
        else:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
Ejemplo n.º 2
0
    def forward(self, inputs):
        """Causal dot-product attention over a (queries, keys, values) triple.

        In 'predict' mode, ``self.state`` is refreshed through
        ``_fast_inference_update_state``. The cached keys, values and mask it
        holds are then used in place of the incoming keys and values.

        Otherwise, a boolean lower-triangular mask of shape ``(1, L, L)`` is
        built, where ``L = q.shape[-2]``. In 'viz' mode the attention dots
        returned by ``DotProductAttention`` are stored in ``self.state``.
        """
        queries, keys, values = inputs

        if self._mode != 'predict':
            length = queries.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is
            # inefficient in that it creates a large global constant, so it is
            # only the fallback. TODO(kitaev): try to find an alternative that
            # works across all backends.
            tri_lib = jnp if fastmath.backend_name() == 'jax' else np
            mask = tri_lib.tril(tri_lib.ones((1, length, length),
                                             dtype=np.bool_),
                                k=0)
        else:
            self.state = _fast_inference_update_state(inputs, self.state)
            keys, values, mask, _ = self.state

        result, dots = DotProductAttention(queries,
                                           keys,
                                           values,
                                           mask,
                                           dropout=self._dropout,
                                           mode=self._mode,
                                           rng=self.rng)
        if self._mode == 'viz':
            self.state = dots
        return result
Ejemplo n.º 3
0
  def fbo(inputs, weights, state, slots, opt_params, rng, step, grads):
    """FBO of the layer: forward pass, backward pass and optimizer update.

    Reads ``layer``, ``stats_name``, ``optimizer`` and ``n_devices`` from the
    enclosing scope.

    Args:
      inputs: layer inputs.
      weights: layer weights.
      state: layer state, passed unchanged to ``layer.pure_fn``.
      slots: optimizer slots.
      opt_params: optimizer hyper-parameters.
      rng: random key, passed unchanged to ``layer.pure_fn``.
      step: training step, forwarded to ``optimizer.tree_update``.
      grads: cotangent of the layer outputs. When it is None and
        ``stats_name`` is set (the loss layer), a scalar 1 is used.

    Returns:
      (new_weights, new_state, new_slots, grads_inputs, stats). For layers
      with no weights, the weights and slots are returned unchanged.
    """
    def pure_fn_without_state_and_rng(x, w):
      # Close over state and rng so only (inputs, weights) are differentiated.
      return layer.pure_fn(x, w, state, rng)

    activations, vjp_fn, new_state = fastmath.vjp(
        pure_fn_without_state_and_rng, inputs, weights, has_aux=True)

    is_loss_layer = stats_name is not None
    if grads is None and is_loss_layer:
      # Loss layer: seed backprop with a scalar 1 in the activations' dtype.
      grads = jnp.ones((), dtype=activations.dtype)

    grads_inputs, grads_weights = vjp_fn(grads)

    if _is_empty_tuple(weights):
      # Nothing to train; report activations as stats for the loss layer.
      stats = {stats_name: activations} if is_loss_layer else {}
      return weights, new_state, slots, grads_inputs, stats

    if n_devices > 1:
      grads_weights = _average_multidevice_gradients(grads_weights)

    new_weights, new_slots, stats = optimizer.tree_update(
        step, grads_weights, weights, slots, opt_params)
    if is_loss_layer:
      stats[stats_name] = activations
    return new_weights, new_state, new_slots, grads_inputs, stats
Ejemplo n.º 4
0
def _fast_inference_update_state(inputs, state):
    """Updates state of a causal attention layer for fast inference.

    The layer state stores tensors with cached values of keys and values, as
    well as the mask and an index. To make shapes static, keys and values in
    the state are long, and the index indicates where the new keys and values
    from inputs need to be appended. The mask ensures that attention only
    looks at keys up to the index.

    During an update, new_keys and new_values are written into keys and
    values at the position given by index. The mask is set to 1 at the new key
    positions, and index is incremented by the length of the new keys. The
    written ones use the mask's own dtype, so the cached mask may be float,
    int or bool.

    Args:
      inputs: a triple (new_queries, new_keys, new_values); the new length is
        read from new_keys.shape[1].
      state: layer state with (keys, values, mask, index).

    Returns:
      Updated state (keys, values, mask, index + new_length).
    """
    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    length = new_k.shape[1]
    (ks, vs, mask, idx) = state
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code
    # path with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
    vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
    # Mask is of shape [batch_size, 1 (for heads), length]. The ones are built
    # in the mask's dtype: jax.lax.dynamic_update_slice requires operand and
    # update to share a dtype, so the float32 default would break a bool/int
    # mask.
    new_mask = jnp.ones((mask.shape[0], mask.shape[1], length),
                        dtype=mask.dtype)
    mask = fastmath.dynamic_update_slice_in_dim(mask, new_mask, idx, axis=2)
    return (ks, vs, mask, idx + length)
Ejemplo n.º 5
0
    def forward(self, inputs):
        """Returns attention-computed activations.

        Args:
          inputs: A (queries, keys, values) tuple.

        In 'predict' mode, keys, values and mask come from ``self.state``
        after it is refreshed by ``_fast_inference_update_state``. Otherwise,
        a ``(1, L, L)`` boolean lower-triangular mask is built, where
        ``L = q.shape[-2]``. In 'viz' mode the attention dots are stored in
        ``self.state``.
        """
        q, k, v = inputs

        if self._mode == 'predict':
            self.state = _fast_inference_update_state(inputs, self.state)
            k, v, mask, _ = self.state
        else:
            n = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is
            # inefficient in that it creates a large global constant.
            # TODO(kitaev): try to find an alternative that works across all
            # backends.
            on_jax = fastmath.is_backend(fastmath.Backend.JAX)
            xp = jnp if on_jax else np
            mask = xp.tril(xp.ones((1, n, n), dtype=np.bool_), k=0)

        activations, dots = DotProductAttention(q, k, v, mask,
                                                dropout=self._dropout,
                                                mode=self._mode,
                                                rng=self.rng)
        if self._mode == 'viz':
            self.state = dots
        return activations
Ejemplo n.º 6
0
        def loss_fbo(inputs, weights, state, slots, opt_params, rng, step):
            """FBO of the final loss layer.

            Runs the forward pass, the backward pass and the update from the
            last optimizer. Reads ``loss_layer`` and ``self`` from the
            enclosing scope.

            Returns:
              (new_weights, new_state, new_slots, grads_inputs, stats). Here
              ``stats`` comes from the last optimizer's ``tree_update`` and
              also holds the loss under the key ``'loss'``.
            """
            def loss_pure_fn_without_state_and_rng(x, w):
                # Close over state and rng: differentiate w.r.t. x and w only.
                return loss_layer.pure_fn(x, w, state, rng)

            loss, vjp_fn, new_state = fastmath.vjp(
                loss_pure_fn_without_state_and_rng, inputs, weights,
                has_aux=True)

            # The loss is a scalar and this is the last layer, so backprop is
            # seeded with 1.0 in the loss dtype.
            seed = jnp.ones((), dtype=loss.dtype)
            grads_inputs, grads_weights = vjp_fn(seed)

            if self._n_devices > 1:
                grads_weights = _average_multidevice_gradients(grads_weights)

            # The loss layer is last, so its optimizer is the last one.
            loss_optimizer = self._optimizers[-1]
            new_weights, new_slots, stats = loss_optimizer.tree_update(
                step, grads_weights, weights, slots, opt_params)
            stats['loss'] = loss
            return new_weights, new_state, new_slots, grads_inputs, stats
Ejemplo n.º 7
0
def _causal_mask(length):
  """Returns a boolean lower-triangular mask of shape (1, length, length)."""
  shape = (1, length, length)
  if not fastmath.is_backend(fastmath.Backend.JAX):
    # Not all backends define jnp.tril. np.tril works everywhere but creates a
    # large global constant. TODO(kitaev): try to find an alternative that
    # works across all backends.
    return np.tril(np.ones(shape, dtype=np.bool_), k=0)
  return jnp.tril(jnp.ones(shape, dtype=np.bool_), k=0)
Ejemplo n.º 8
0
 def update_mask(mask, x_times_one_minus_f):  # pylint: disable=invalid-name
     """Builds a mask over the first two axes of ``x_times_one_minus_f``.

     The result starts as all ones. When axis 1 is longer than 1, ``mask``
     (compared against 0) is written into it starting at index 1 along
     axis 1. ``x_times_one_minus_f`` is returned unchanged as the second
     element.

     NOTE(review): the short branch returns float32 ones, while the other
     branch returns a boolean array. Confirm that callers tolerate both
     dtypes.
     """
     ones = jnp.ones(x_times_one_minus_f.shape[:2], dtype=jnp.float32)
     if ones.shape[1] <= 1:
         return ones, x_times_one_minus_f
     merged = fastmath.dynamic_update_slice_in_dim(ones != 0,
                                                   mask != 0,
                                                   1,
                                                   axis=1)
     return merged, x_times_one_minus_f
Ejemplo n.º 9
0
    def test_pure_lsh_wrapper_non_causal_masked(self, num_weights):
        """Checks weight, output and state shapes of a masked, non-causal
        PureLSHSelfAttentionWrapper for the parameterized ``num_weights``."""
        with fastmath.use_backend(fastmath.Backend.JAX):
            n_heads = 5
            batch, seqlen, d_head = 3, 32, 8
            # num_weights comes from the test parameter; it must not be
            # overwritten here or every parameterization tests the same case.
            n_hashes = 2
            d_model = n_heads * d_head
            layer = efficient_attention.PureLSHSelfAttentionWrapper(
                n_heads=n_heads,
                d_qk=d_head,
                d_v=d_head,
                causal=False,
                masked=True,
                chunk_len=8,
                n_chunks_before=1,
                n_chunks_after=0,
                n_hashes=n_hashes,
                n_buckets=4,
                bias=False,
                pure_lsh_implementation=efficient_attention.
                PureLSHSelfAttention,
                mode='train',
                num_weights=num_weights)

            rng = jax.random.PRNGKey(0)
            _, x_rng = jax.random.split(rng)

            input_shape = (batch, seqlen, d_model)
            x = jax.random.uniform(x_rng, input_shape, dtype=jnp.float32)
            mask = jnp.ones((batch, seqlen), dtype=jnp.int32)

            inp = (x, mask)
            w, s = layer.init(shapes.signature(inp))
            o = layer(inp)

            # Get the actual weights.
            weights = fastmath.tree_leaves(w)
            # Assert number of weights is as expected, the extra 1 is for output.
            self.assertLen(weights, num_weights + 1)

            # Assert each weight is of the expected shape.
            for i in range(num_weights + 1):
                self.assertEqual(weights[i].shape, (d_model, d_model))

            # Test that the output and the x's shape match.
            self.assertEqual(x.shape, o.shape)

            # Assert state is the shape expected.
            state = fastmath.tree_leaves(s)
            self.assertLen(state, 2)
            # buckets
            self.assertEqual(state[0].shape,
                             (batch * n_heads, n_hashes * seqlen))
            # rngs
            self.assertEqual(state[1].shape, (batch * n_heads, 2))
Ejemplo n.º 10
0
 def init_weights_and_state(self, input_signature):
   """Helper to initialize batch norm weights and state.

   Weights are ``(beta, gamma)``. They have the input shape with the
   normalized axes removed. Each one is replaced by ``()`` when centering or
   scaling is disabled. State is ``(running_mean, running_var, n_batches)``.
   The running statistics keep every input dimension, with each normalized
   axis collapsed to size 1.
   """
   axis = self._axis
   axes = (axis,) if jnp.isscalar(axis) else axis
   input_shape = input_signature.shape
   param_shape = tuple(
       dim for i, dim in enumerate(input_shape) if i not in axes)
   # TODO(jonni): Should beta and gamma match the dtype in the input signature?
   beta = jnp.zeros(param_shape, dtype='float32') if self._center else ()
   gamma = jnp.ones(param_shape, dtype='float32') if self._scale else ()
   stats_shape = tuple(
       1 if i in axes else dim for i, dim in enumerate(input_shape))
   running_mean = jnp.zeros(stats_shape, dtype=jnp.float32)
   running_var = jnp.ones(stats_shape, dtype=jnp.float32)
   n_batches = jnp.zeros((), dtype=jnp.int64)
   self.weights = (beta, gamma)
   self.state = (running_mean, running_var, n_batches)
Ejemplo n.º 11
0
    def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Funnel mask.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True or False.

    Returns:
      funnel mask.

    This function based on keys/queries lengths creates a triangle mask
    that prevents tokens from attending to positions following it.

    If funnel_factor is not equal to 1 due to funnel upsampling or
    downsampling it adjusts created mask for funnel attention
    by repeating each element funnel_factor times.

    This is because after funnel layer one token attends to funnel_factor
    different tokens in downsampling. During upsampling on the other hand
    funnel_factor tokens are attending to single token before upsampling.
    """

        if funnel_factor != 1:
            if not is_upsampling:
                mask = jnp.tril(
                    jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-1)
            else:
                mask = jnp.tril(jnp.ones((keys_len, keys_len),
                                         dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-2)
        else:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
Ejemplo n.º 12
0
def dot_product_self_attention(q, k, v):
    """Masked (causal) dot-product self-attention.

    Args:
        q: queries (jax array); ``q.shape[-2]`` is the sequence length.
        k: keys (jax array).
        v: values (jax array).

    Returns:
        The result of ``DotProductAttention(q, k, v, mask)`` with a causal
        boolean mask.
    """
    seq_len = q.shape[-2]
    # (1, seq_len, seq_len) boolean mask: True on and below the diagonal.
    causal_mask = jnp.tril(
        jnp.ones((1, seq_len, seq_len), dtype=jnp.bool_), k=0)
    return DotProductAttention(q, k, v, causal_mask)
Ejemplo n.º 13
0
    def init_weights_and_state(self, input_signature):
        """Creates per-channel gamma (ones) and beta (zeros) weights.

        The channel count is the last axis of ``input_signature.shape``,
        usually laid out as (B, W, H, C). The third weight is
        ``(self._init_learnt_epsilon,)`` when ``self._learn_epsilon`` is set;
        otherwise it is ``base.EMPTY_WEIGHTS``. No state is assigned here.
        """
        channels = input_signature.shape[-1]

        gamma = jnp.ones((channels,), dtype=jnp.float32)
        beta = jnp.zeros((channels,), dtype=jnp.float32)

        no_epsilon = base.EMPTY_WEIGHTS
        epsilon_l = ((self._init_learnt_epsilon,) if self._learn_epsilon
                     else no_epsilon)

        self.weights = gamma, beta, epsilon_l
Ejemplo n.º 14
0
  def forward(self, inputs):
    """Returns a boolean causal mask for ``inputs``.

    Outside 'predict' mode, this is a lower-triangular ``(n, n)`` mask.
    ``n`` is ``self._chunk_len`` when it is set, otherwise
    ``inputs.shape[1]``.

    In 'predict' mode exactly one token is processed. The mask is the
    ``(1, sequence_length)`` row of positions ``<= current_token``, both
    taken from ``calc_predict_next_token_index``. ``self.state`` is then
    advanced by ``self._n_raw_tokens_generated``.
    """
    seq_len = inputs.shape[1]

    if self._mode != 'predict':
      size = seq_len if self._chunk_len is None else self._chunk_len
      return jnp.tril(jnp.ones((size, size), dtype=jnp.bool_))

    # We cannot generate more than one token because it contradicts
    # all autoregressive properties.
    assert seq_len == 1
    current_token, sequence_length = calc_predict_next_token_index(
        self.state, self._total_kv_pooling, self._max_len, self._chunk_len,
        self._chunk_offset)
    visible = jnp.arange(sequence_length) <= current_token
    mask = jnp.reshape(visible, (1, sequence_length))
    self.state += self._n_raw_tokens_generated
    return mask
Ejemplo n.º 15
0
 def test_autoregressive_sample_transformer(self):
     """Samples from a tiny predict-mode Transformer and checks that the
     first two output dimensions are (1, 10)."""
     vocab_size = 10
     sample_length = 10
     model = models.Transformer(vocab_size,
                                d_model=32,
                                d_ff=64,
                                n_encoder_layers=1,
                                n_decoder_layers=1,
                                n_heads=2,
                                mode='predict')
     inputs = jnp.ones((1, 3), dtype=jnp.int32)
     # Encoder input signature plus a single-token decoder target signature.
     input_signature = (shapes.signature(inputs),
                        shapes.ShapeDtype((1, 1), dtype=jnp.int32))
     model.init(input_signature)
     sample = trainer_lib.autoregressive_sample(model,
                                                inputs=inputs,
                                                eos_id=-1,
                                                max_length=sample_length)
     self.assertEqual(sample.shape[0], 1)
     self.assertEqual(sample.shape[1], sample_length)
Ejemplo n.º 16
0
def dot_product_self_attention(q, k, v):
    """Masked (causal) dot-product self-attention.

    Args:
        q: queries (jax array); ``q.shape[-2]`` is the sequence length.
        k: keys (jax array).
        v: values (jax array).

    Returns:
        The result of ``DotProductAttention(q, k, v, mask)`` for the causal
        mask built here.
    """
    # The mask is laid out as (query position, key position), i.e. Lq x Lk.
    # Entry [i, j] is True when j <= i. For self-attention Lq == Lk, so the
    # query length is used for both sides.
    mask_size = q.shape[-2]

    # Boolean (1, mask_size, mask_size) matrix: True on and below the
    # diagonal, False above it. It is passed as a boolean mask; how it is
    # applied to the Q . K^T scores is up to DotProductAttention.
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)

    return DotProductAttention(q, k, v, mask)
Ejemplo n.º 17
0
 def init_weights_and_state(self, input_signature):
     """Sets ``self.weights`` to (scale, bias) over the last input axis.

     Scale is all ones and bias is all zeros, both in the input dtype.
     """
     n_features = input_signature.shape[-1]
     dtype = input_signature.dtype
     self.weights = (jnp.ones(n_features, dtype=dtype),
                     jnp.zeros(n_features, dtype=dtype))
Ejemplo n.º 18
0
def dot_product_self_attention(q, k, v):
    """Applies DotProductAttention with a causal (lower-triangular) mask.

    The mask has shape (1, L, L) with L = q.shape[-2], and is True on and
    below the diagonal.
    """
    length = q.shape[-2]
    all_true = jnp.ones((1, length, length), dtype=jnp.bool_)
    return DotProductAttention(q, k, v, jnp.tril(all_true, k=0))
Ejemplo n.º 19
0
def _fast_inference_update_state(inputs, state, mask_for_predict=None):
    """Updates state of a causal attention layer for fast inference.

    The layer state stores arrays with cached values of keys and values, as
    well as an index. To make shapes static, keys and values in the state are
    long, and the index indicates where the new keys and values from inputs
    need to be appended.

    During an update, new_keys and new_values are written into keys and
    values at the position given by index, and index is incremented by the
    length of the new keys. A causal mask is also built: for new query i
    (global position idx + i) it is True at every key position <= idx + i.

    When mask_for_predict is given, the state also carries a 1-D vector,
    state_mask_for_predict, stored first in the tuple. The function:
      1. overwrites its prefix with mask_for_predict (flattened, as bool);
      2. sets it to True at position sum(mask_for_predict);
      3. sets it to True at position idx;
      4. multiplies it elementwise into the causal mask along the key axis.

    Args:
      inputs: a triple (new_queries, new_keys, new_values). Only keys and
        values are used; the new length is new_keys.shape[1].
      state: layer state (keys, values, index), or
        (state_mask_for_predict, keys, values, index) when mask_for_predict is
        not None.
      mask_for_predict: mask used for predict mode. This is used only in
        Terraformer.

    Returns:
      A pair (new_state, mask). new_state has the same layout as ``state``
      with index advanced by the new length. mask has shape
      [1, new_length, k_length], where k_length is the cached key length.
    """
    # Fast inference: run step-by-step, storing the sequence
    # of keys and values calculated so far in state.
    (_, new_k, new_v) = inputs
    # The state layout depends on whether the Terraformer predict mask is used.
    if mask_for_predict is not None:
        (state_mask_for_predict, ks, vs, idx) = state
    else:
        (ks, vs, idx) = state
    length = new_k.shape[1]
    # TODO(lukaszkaiser): benchmark speed and decide if using a separate code path
    # with index_update when length == 1 is worth it.
    # Keys and values are of shape [batch_size, length, d_kv].
    ks = fastmath.dynamic_update_slice_in_dim(ks, new_k, idx, axis=1)
    vs = fastmath.dynamic_update_slice_in_dim(vs, new_v, idx, axis=1)
    # Full (static) length of the key cache, not just the filled part.
    k_length = ks.shape[1]

    # Mask is of shape [1, q_length, k_length].
    # Mask should be true for every pair of (query_token, key_token) such that
    # index of query_token is equal or larger to index of key_token.
    mask = (jnp.reshape(jnp.arange(k_length), (1, 1, k_length)) <= jnp.reshape(
        jnp.arange(length) + idx, (1, length, 1)))
    if mask_for_predict is None:
        return (ks, vs, idx + length), mask
    else:
        # Step 1: copy mask_for_predict (flattened; `(-1)` is the int -1, not a
        # tuple) into the start of the state mask, both compared against 0.
        state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
            state_mask_for_predict != 0,
            mask_for_predict.reshape((-1)) != 0,
            0,
            axis=0)

        # Step 2: mark position sum(mask_for_predict) as visible.
        # NOTE(review): presumably the slot right after the prompt tokens
        # flagged in mask_for_predict. Confirm against the Terraformer caller.
        state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
            state_mask_for_predict != 0,
            jnp.ones((1, )) != 0,
            jnp.sum(mask_for_predict, dtype=jnp.int32),
            axis=0)

        # Step 3: mark the current write position idx as visible.
        state_mask_for_predict = fastmath.dynamic_update_slice_in_dim(
            state_mask_for_predict != 0, jnp.ones((1, )) != 0, idx, axis=0)
        # Step 4: broadcast the per-key state mask over queries and combine
        # it elementwise with the causal mask.
        placeholder = jnp.reshape(state_mask_for_predict != 0, (
            1,
            1,
            mask.shape[2],
        ))
        mask = mask * placeholder

        return (state_mask_for_predict, ks, vs, idx + length), mask
Ejemplo n.º 20
0
 def init_weights_and_state(self, input_signature):
     """Sets zero weights of shape (2, 3) and an all-ones state shaped like the input."""
     weights_shape = (2, 3)
     self.weights = jnp.zeros(weights_shape)
     state_shape = input_signature.shape
     self.state = jnp.ones(state_shape)
Ejemplo n.º 21
0
 def bidirectional_denominator(query_prime, key_prime):
     """Returns out[l, b] = sum_m query_prime[l, b, m] * sum_j key_prime[j, b, m].

     Both inputs are indexed 'lbm' by the einsums below. The ones vector has
     length query_prime.shape[0], so key_prime's leading axis must match it.
     The result has shape (L, B).
     """
     length = query_prime.shape[0]
     # Contracting with a (float32) ones vector sums key_prime over axis l.
     key_totals = jnp.einsum('lbm,l->bm', key_prime, jnp.ones([length]))
     return jnp.einsum('lbm,bm->lb', query_prime, key_totals)