Example #1
    def _funnel_mask(self, batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Creates a funnel mask.

    This function based on keys/queries lengths creates a triangle mask
    that prevents tokens from attending to positions following it.

    If funnel_factor is not equal to 1 due to funnel upsampling or
    downsampling it adjusts created mask for funnel attention
    by repeating each element funnel_factor times.

    This is because after funnel layer one token attends to funnel_factor
    different tokens in downsampling. During upsampling on the other hand
    funnel_factor tokens are attending to single token before upsampling.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: upsampling if set to True.

    Returns:
      Funnel mask.
    """

        if self._mode == 'predict':
            # In predict mode we generate one token at a time; more than one
            # query here would violate the autoregressive property.
            assert queries_len == 1
            mask = jnp.arange(
                self._max_len) <= (self.state // self._total_kv_pooling)
            mask = jnp.reshape(mask, (1, 1, 1, self._max_len))
            mask = jnp.repeat(mask, batch_size, axis=0)
            self.state += self._n_raw_tokens_generated
            return mask

        if funnel_factor != 1:
            if not is_upsampling:
                mask = jnp.tril(
                    jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-1)
            else:
                mask = jnp.tril(jnp.ones((keys_len, keys_len),
                                         dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-2)
        else:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
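
For intuition, the mask shapes produced by the training-mode branch above can be
checked with a small standalone sketch (assuming only jax is installed; the
batch size, lengths, and funnel_factor below are arbitrary illustrative values):

import jax.numpy as jnp

def funnel_mask_sketch(batch_size, keys_len, queries_len, funnel_factor,
                       is_upsampling):
  # Mirrors the non-predict branch of _funnel_mask above.
  if funnel_factor != 1:
    if not is_upsampling:
      # Downsampling: each query attends to funnel_factor key positions.
      mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
      mask = jnp.repeat(mask, funnel_factor, axis=-1)
    else:
      # Upsampling: funnel_factor queries attend to a single pre-upsampling key.
      mask = jnp.tril(jnp.ones((keys_len, keys_len), dtype=jnp.bool_))
      mask = jnp.repeat(mask, funnel_factor, axis=-2)
  else:
    mask = jnp.tril(jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
  return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)

# Downsampling: 6 tokens pooled to 3 queries attending over 6 keys.
print(funnel_mask_sketch(2, 6, 3, 2, False).shape)  # (2, 1, 3, 6)
# Upsampling: 3 pooled tokens expanded to 6 queries attending over 3 keys.
print(funnel_mask_sketch(2, 3, 6, 2, True).shape)   # (2, 1, 6, 3)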
Example #2
    def forward(self, inputs):
        """Returns attention-computed activations.

    Args:
      inputs: A (queries, keys, values) tuple.
    """
        q, k, v = inputs

        if self._mode == 'predict':
            self.state = _fast_inference_update_state(inputs, self.state)
            (k, v, mask, _) = self.state
        else:
            mask_size = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if fastmath.is_backend(fastmath.Backend.JAX):
                mask = jnp.tril(jnp.ones((1, mask_size, mask_size),
                                         dtype=np.bool_),
                                k=0)
            else:
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=np.bool_),
                               k=0)

        res, dots = DotProductAttention(q,
                                        k,
                                        v,
                                        mask,
                                        dropout=self._dropout,
                                        mode=self._mode,
                                        rng=self.rng)
        if self._mode == 'viz':
            self.state = dots
        return res
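
DotProductAttention itself is not shown in these examples. The sketch below is
a minimal, self-contained version of the masked dot-product step it is assumed
to perform (push masked-out positions to a large negative value before the
softmax); it omits dropout and the prediction-mode state handling:

import jax.numpy as jnp
from jax.nn import softmax

def masked_dot_product_attention(q, k, v, mask):
  # q, k, v: [..., seq_len, depth]; mask broadcastable to [..., q_len, k_len].
  depth = q.shape[-1]
  dots = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(depth)
  # Where mask is False, replace the score with a large negative value so the
  # softmax assigns it (almost) zero weight.
  dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  weights = softmax(dots, axis=-1)
  return jnp.matmul(weights, v), weights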
Example #3
    def forward(self, inputs):
        q, k, v = inputs

        if self._mode == 'predict':
            self.state = _fast_inference_update_state(inputs, self.state)
            (k, v, mask, _) = self.state
        else:
            mask_size = q.shape[-2]
            # Not all backends define jnp.tril. However, using np.tril is inefficient
            # in that it creates a large global constant. TODO(kitaev): try to find an
            # alternative that works across all backends.
            if fastmath.backend_name() == 'jax':
                mask = jnp.tril(jnp.ones((1, mask_size, mask_size),
                                         dtype=np.bool_),
                                k=0)
            else:
                mask = np.tril(np.ones((1, mask_size, mask_size),
                                       dtype=np.bool_),
                               k=0)

        res, dots = DotProductAttention(q,
                                        k,
                                        v,
                                        mask,
                                        dropout=self._dropout,
                                        mode=self._mode,
                                        rng=self.rng)
        if self._mode == 'viz':
            self.state = dots
        return res
Example #4
def _causal_mask(length):
  # Not all backends define jnp.tril. However, using np.tril is inefficient
  # in that it creates a large global constant. TODO(kitaev): try to find an
  # alternative that works across all backends.
  if fastmath.is_backend(fastmath.Backend.JAX):
    return jnp.tril(jnp.ones((1, length, length), dtype=np.bool_), k=0)
  else:
    return np.tril(np.ones((1, length, length), dtype=np.bool_), k=0)
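
For intuition, the numpy fallback branch produces an ordinary lower-triangular
boolean matrix with a leading singleton axis, so it broadcasts over batch and
head dimensions; a quick check (length 4 chosen arbitrarily):

import numpy as np

mask = np.tril(np.ones((1, 4, 4), dtype=np.bool_), k=0)
print(mask[0].astype(int))
# [[1 0 0 0]
#  [1 1 0 0]
#  [1 1 1 0]
#  [1 1 1 1]]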
Example #5
    def _funnel_mask(batch_size, keys_len, queries_len, funnel_factor,
                     is_upsampling):
        """Funnel mask.

    Args:
      batch_size: batch size.
      keys_len: keys length.
      queries_len: queries length.
      funnel_factor: funnel factor.
      is_upsampling: True or False.

    Returns:
      funnel mask.

    This function based on keys/queries lengths creates a triangle mask
    that prevents tokens from attending to positions following it.

    If funnel_factor is not equal to 1 due to funnel upsampling or
    downsampling it adjusts created mask for funnel attention
    by repeating each element funnel_factor times.

    This is because after funnel layer one token attends to funnel_factor
    different tokens in downsampling. During upsampling on the other hand
    funnel_factor tokens are attending to single token before upsampling.
    """

        if funnel_factor != 1:
            if not is_upsampling:
                mask = jnp.tril(
                    jnp.ones((queries_len, queries_len), dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-1)
            else:
                mask = jnp.tril(jnp.ones((keys_len, keys_len),
                                         dtype=jnp.bool_))
                mask = jnp.repeat(mask, funnel_factor, axis=-2)
        else:
            mask = jnp.tril(
                jnp.ones((queries_len, queries_len), dtype=jnp.bool_))

        return jnp.repeat(mask[None, None, :, :], batch_size, axis=0)
Example #6
def dot_product_self_attention(q, k, v):
    """ Masked dot product self attention.
    Args:
        q (jax.interpreters.xla.DeviceArray): queries.
        k (jax.interpreters.xla.DeviceArray): keys.
        v (jax.interpreters.xla.DeviceArray): values.
    Returns:
        jax.interpreters.xla.DeviceArray: masked dot product self attention tensor.
    """
    mask_size = q.shape[-2]
    # Creates a boolean matrix with True on and below the diagonal and False
    # above it, with shape (1, mask_size, mask_size).
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)
    return DotProductAttention(q, k, v, mask)
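
As a hypothetical usage example (shapes chosen arbitrarily, and assuming a
DotProductAttention implementation that returns the attended values directly,
as the call above expects, is in scope):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
# Batch of 2 sequences, 5 tokens each, feature depth 8.
q = jax.random.normal(key, (2, 5, 8))
result = dot_product_self_attention(q, q, q)  # self-attention: k = v = q
print(result.shape)  # expected: (2, 5, 8)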
Example #7
  def forward(self, inputs):
    inputs_len = inputs.shape[1]

    if self._mode == 'predict':
      # In predict mode we generate one token at a time; more than one input
      # token here would violate the autoregressive property.
      assert inputs_len == 1

      current_token, sequence_length = calc_predict_next_token_index(
          self.state, self._total_kv_pooling, self._max_len, self._chunk_len,
          self._chunk_offset)

      mask = jnp.arange(sequence_length) <= current_token
      mask = jnp.reshape(mask, (1, sequence_length))
      self.state += self._n_raw_tokens_generated
      return mask

    if self._chunk_len is not None:
      return jnp.tril(
          jnp.ones((self._chunk_len, self._chunk_len), dtype=jnp.bool_))

    return jnp.tril(jnp.ones((inputs_len, inputs_len), dtype=jnp.bool_))
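
The predict-mode branch builds a growing prefix mask one token at a time.
calc_predict_next_token_index and the pooling bookkeeping are not shown in
these examples, so the values below are only a standalone illustration of the
mask pattern itself:

import jax.numpy as jnp

sequence_length = 6
for current_token in range(3):
  # The token at position current_token may attend to itself and to every
  # position before it.
  mask = jnp.arange(sequence_length) <= current_token
  print(mask.astype(int))
# [1 0 0 0 0 0]
# [1 1 0 0 0 0]
# [1 1 1 0 0 0]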
Example #8
def dot_product_self_attention(q, k, v):
    """ Masked dot product self attention.
    Args:
        q (jax.interpreters.xla.DeviceArray): queries.
        k (jax.interpreters.xla.DeviceArray): keys.
        v (jax.interpreters.xla.DeviceArray): values.
    Returns:
        jax.interpreters.xla.DeviceArray: masked dot product self attention tensor.
    """
    # For causal attention, the mask M is applied to the scores Q @ K^T.
    # In self-attention Lq == Lk, so the mask is Lq x Lq.
    mask_size = q.shape[-2]

    # Creates a boolean matrix with True on and below the diagonal and False
    # above it, with shape (1, mask_size, mask_size). The ones and zeros are
    # cast to True/False by setting dtype to jnp.bool_ (jnp.tril keeps the
    # lower triangle of jnp.ones).
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)

    return DotProductAttention(q, k, v, mask)
Example #9
def dot_product_self_attention(q, k, v):
    mask_size = q.shape[-2]
    mask = jnp.tril(jnp.ones((1, mask_size, mask_size), dtype=jnp.bool_), k=0)
    return DotProductAttention(q, k, v, mask)