Example #1
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.

    Computes ``softmax(query @ key^T / sqrt(d)) @ value``, where the softmax
    is taken over the last axis and masked-out scores are replaced by -1e9
    beforehand.

    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k).
            Scores are kept where the mask is truthy and replaced by -1e9
            elsewhere. May be None, in which case no masking is applied.

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[
        -1], "Embedding dimensions of q, k, v aren't all the same"

    # The embedding size d is used to scale the raw scores.
    depth = query.shape[-1]

    # Scaled query-key dot product: Q K^T / sqrt(d).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    if mask is not None:
        # A large negative score becomes (numerically) zero after the softmax.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax computed as exp(dots - logsumexp(dots)).
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    dots = jnp.exp(dots - logsumexp)

    # Weighted sum of the value vectors.
    attention = jnp.matmul(dots, value)

    return attention
Example #2
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K dot
      products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probababilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only in ``'train'``
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
  if dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]

  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask,
                     jnp.full_like(dots, -1e9))
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
Example #3
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
        May be None, in which case no masking is applied.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values. May be
        None (no dropout).
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple `(out, dots)`: `out` holds the per-head activations resulting from
    the masked per-head attention-weighted sum of per-head values, and `dots`
    holds the attention weights used to form them. Both are cast to float32.

  Raises:
    ValueError: if `dropout` is greater than or equal to 1.
  """
  # Validate up front; `dropout` may be None, so guard before comparing.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')

  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data dependency
    # on the input. Broadcasted copies of these use a lot of memory, so they
    # should be computed at runtime (rather than being global constants).
    # NOTE(review): `jax.lax.tie_in` appears to be absent from newer JAX
    # releases, so it is only called when present -- confirm against the
    # JAX version this code targets.
    if (fastmath.is_backend(fastmath.Backend.JAX)
        and hasattr(jax.lax, 'tie_in')):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    # Rescale kept weights by 1 / (1 - dropout); dropped weights become zero.
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
Example #4
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, separate_cls, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head `queries` and `keys`,
      - applies `mask` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head `values` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention weights.
      pos_emb: Per-head activations representing positional embeddings.
      context_bias: Global context bias from Transformer XL's attention.
      location_bias: Global location bias from Transformer XL's attention.
      mask: Mask that distinguishes positions with real content vs. padding.
        May be None, in which case no masking is applied.
      separate_cls: True/False if we separate_cls in calculations. When true,
        the positional score term is zeroed for index 0 along both of its
        last two axes.
      dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values. May be
        None (no dropout).
      mode: One of `'train'`, `'eval'`, or `'predict'`.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Tuple `(out, dots)`: `out` holds the per-head activations resulting from
      the masked per-head attention-weighted sum of per-head values, and
      `dots` holds the attention weights. Both are cast to float32.

    Raises:
      ValueError: if `dropout` is greater than or equal to 1.
    """
    # Validate up front; `dropout` may be None, so guard before comparing.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')

    d_feature = queries.shape[-1]
    keys_len, queries_len = keys.shape[-2], queries.shape[-2]
    funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

    # `ac`: scores of (queries + context_bias) against keys.
    # `bd`: scores of (queries + location_bias) against positional embeddings.
    ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
    bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)

    if mode != 'predict':
        bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling)

    if separate_cls:
        # Masking out location part of attention for cls token
        bd = bd.at[:, :, :, 0].set(0)
        bd = bd.at[:, :, 0, :].set(0)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale kept weights by 1 / (1 - dropout); dropped weights become zero.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
Example #5
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
    """Computes new activations via masked attention-weighted sum of values.

    This function is the core of the attention mechanism. It:
      - computes per-head attention weights from per-head ``queries`` and
        ``keys``,
      - applies ``mask`` to screen out positions that come from padding tokens,
      - optionally applies dropout to attention weights, and
      - uses attention weights to combine per-head ``values`` vectors.

    Args:
      queries: Per-head activations representing attention queries.
      keys: Per-head activations representing attention keys.
      values: Per-head activations to be combined by computed attention weights.
      mask: Mask that distinguishes positions with real content vs. padding.
          May be None, in which case no masking is applied.
      dropout: Probabilistic rate for attention dropout, which overrides
          (sets to zero) some attention strengths derived from query-key
          matching. As a result, on a given forward pass, some value vectors
          don't contribute to the output, analogous to how regular dropout can
          cause some node activations to be ignored. Only applied when
          ``mode`` is ``'train'``. May be None (no dropout).
      mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
      rng: Single-use random number generator (JAX PRNG key).

    Returns:
      Tuple ``(out, dots)``: ``out`` holds the per-head activations resulting
      from the masked per-head attention-weighted sum of per-head values, and
      ``dots`` holds the attention weights. Both are cast to float32.

    Raises:
      ValueError: if ``dropout`` is greater than or equal to 1.
    """
    # Validate up front; `dropout` may be None, so guard before comparing.
    if dropout is not None and dropout >= 1.0:
        raise ValueError('Dropout rates must be lower than 1.')

    d_feature = queries.shape[-1]
    dots = jnp.matmul(queries, jnp.swapaxes(keys, -1,
                                            -2)) / jnp.sqrt(d_feature)
    if mask is not None:
        # Masked-out scores get a large negative value so softmax sends them to ~0.
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
        keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
        # Rescale kept weights by 1 / (1 - dropout); dropped weights become zero.
        dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
    out = jnp.matmul(dots, values)
    out = out.astype(jnp.float32)
    dots = dots.astype(jnp.float32)
    return out, dots
Example #6
def DotProductAttention(query, key, value, mask):
    """Scaled dot-product attention over the last two axes.

    Scores are ``query @ key^T / sqrt(d)`` with ``d = query.shape[-1]``. Where
    ``mask`` is given, scores at falsy mask positions are replaced by -1e9.
    The scores are then normalized with a softmax over the last axis and used
    to weight ``value``.

    Args:
        query: array of query representations.
        key: array of key representations.
        value: array of value representations.
        mask: attention mask, or None to skip masking.

    Returns:
        The attention-weighted combination of ``value``.

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[-1]

    # Raw similarity scores, scaled by the square root of the embedding size.
    scale = jnp.sqrt(query.shape[-1])
    scores = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / scale

    if mask is not None:
        blocked = jnp.full_like(scores, -1e9)
        scores = jnp.where(mask, scores, blocked)

    # Softmax written as exp(scores - logsumexp(scores)).
    normalizer = trax.fastmath.logsumexp(scores, axis=-1, keepdims=True)
    weights = jnp.exp(scores - normalizer)

    return jnp.matmul(weights, value)
Example #7
def DotProductAttention(query, key, value, mask):
    """Dot product self-attention.

    Args:
        query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d)
        key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d)
        value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k
        mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k).
            Scores are kept where the mask is truthy and replaced by -1e9
            elsewhere. May be None, in which case no masking is applied.

    Returns:
        jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by d)

    Raises:
        AssertionError: if the last dimensions of query, key and value differ.
    """
    assert query.shape[-1] == key.shape[-1] == value.shape[
        -1], "Embedding dimensions of q, k, v aren't all the same"

    # The (Q . K) dot product is scaled down by the square root of depth.
    depth = query.shape[-1]

    # Scaled query-key dot product: Q K^T / sqrt(depth).
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply the mask
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax formula implementation
    # Subtracting logsumexp in log space avoids dividing by a very large
    # sum of exponentials.
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)

    # Note: softmax = e^(dots - logsumexp(dots)) = e^dots / sumexp(dots)
    dots = jnp.exp(dots - logsumexp)

    # Multiply the attention weights by value to get self-attention.
    attention = jnp.matmul(dots, value)

    return attention
Example #8
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
    """Returns `log N(x | mu, eye(diag_sigma))`.

    Evaluates the log-density of a Gaussian with diagonal covariance:
    `-0.5 * (d * log(2 * pi) + sum(log(diag_sigma)) + q)`, where
    `q = sum((x - mu)**2 / diag_sigma)` and `d = mu.shape[-1]`.

    Args:
      x: Points at which to evaluate the density; the last axis is the
        event dimension.
      mu: Mean of the Gaussian; the last axis is the event dimension.
      diag_sigma: Diagonal entries of the covariance matrix (i.e. the
        per-dimension variances); the last axis is the event dimension.

    Returns:
      The log-density, with the event (last) axis reduced away.
    """
    a = mu.shape[-1] * jnp.log(2 * jnp.pi)
    # Log-determinant of a diagonal matrix is the sum of the log entries.
    b = jnp.sum(jnp.log(diag_sigma), axis=-1)
    # Sigma^{-1} (x - mu). The parentheses matter: the difference, not just
    # `mu`, must be divided by the variances.
    y = (x - mu) / diag_sigma
    y = jnp.expand_dims(y, axis=-1)
    xm = jnp.expand_dims(x - mu, axis=-2)
    # Quadratic form (x - mu)^T Sigma^{-1} (x - mu) as a row-by-column matmul.
    c = jnp.matmul(xm, y)
    c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
    return -0.5 * (a + b + c)
Example #9
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
    """Returns `log N(x | mu, sigma)`.

    Evaluates the log-density of a Gaussian with full covariance:
    `-0.5 * (d * log(2 * pi) + log|det(sigma)| + q)`, where
    `q = (x - mu)^T sigma^{-1} (x - mu)` and `d = mu.shape[-1]`.

    Args:
      x: Points at which to evaluate the density; the last axis is the
        event dimension.
      mu: Mean of the Gaussian; the last axis is the event dimension.
      sigma: Covariance matrix; the last two axes are the matrix dimensions.
        Only the log of the absolute determinant is used (the sign returned
        by `slogdet` is discarded).

    Returns:
      The log-density, with the event (last) axis reduced away.
    """
    a = mu.shape[-1] * jnp.log(2 * jnp.pi)
    _, b = jnp.linalg.slogdet(sigma)
    diff = x - mu
    # Solve against an explicit column vector so that batched inputs are
    # unambiguous: a right-hand side with a trailing axis of size 1 is always
    # interpreted as a stack of matrices by `solve`.
    y = jnp.linalg.solve(sigma, jnp.expand_dims(diff, axis=-1))
    xm = jnp.expand_dims(diff, axis=-2)
    # Quadratic form (x - mu)^T sigma^{-1} (x - mu) as a row-by-column matmul.
    c = jnp.matmul(xm, y)
    c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
    return -0.5 * (a + b + c)
Example #10
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, dropout, mode, rng, chunk_len,
                        chunk_offset):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  When `chunk_len` is set (and `mode` is not `'predict'`), the length axis is
  cut into chunks of `chunk_len` positions and attention is computed within
  each chunk independently. With a non-zero `chunk_offset`, the first
  `chunk_offset` positions and the last `chunk_len - chunk_offset` positions
  are zero-padded into chunks of their own, so that chunk boundaries are
  shifted by `chunk_offset`.

  Args:
    queries: Per-head activations representing attention queries; must be
      4-dimensional (unpacked as batch, heads, length, feature depth).
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    pos_emb: Per-head activations representing positional embeddings.
    context_bias: Global context bias from Transformer XL's attention.
    location_bias: Global location bias from Transformer XL's attention.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention strengths
      (based on query-key pairs) before applying them to values. May be None
      (no dropout).
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).
    chunk_len (optional): Number of tokens per chunk. Setting this option will
      enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
      attention.

  Returns:
    Tuple `(out, None)`: `out` holds the per-head activations resulting from
    masked per-head attention-weighted sum of per-head values, cast to
    float32. The attention weights are not returned.

  Raises:
    ValueError: if `dropout` is greater than or equal to 1.
    AssertionError: in chunked mode, if the sequence length is not a positive
      multiple of `chunk_len`, or if `chunk_offset` is not smaller than
      `chunk_len`.
  """
  # Validate up front; `dropout` may be None, so guard before comparing.
  if dropout is not None and dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')

  batch_size, n_heads, original_l, d_feature = queries.shape

  def _calc_attn_scores(q, k):
    """Returns softmax-normalized (and possibly dropped-out) attention weights."""
    # `ac`: scores of (q + context_bias) against keys.
    # `bd`: scores of (q + location_bias) against positional embeddings.
    ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
    bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

    if mode != 'predict':
      bd = _fast_matrix_shift(bd)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    # NOTE(review): unlike the non-chunked variants, `mask` is applied without
    # a None check here -- confirm callers always supply a mask.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
    if dropout is not None and dropout > 0.0 and mode == 'train':
      keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
      dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

    return dots

  if chunk_len is None or mode == 'predict':
    # Plain (non-chunked) attention over the full length.
    full_dots = _calc_attn_scores(queries, keys)
    out = jnp.matmul(full_dots, values)
  else:
    assert original_l % chunk_len == 0 and original_l >= chunk_len

    def chunk_split(v):
      """Folds chunks of the length axis into the batch axis."""
      total_len = v.shape[2]
      assert total_len % chunk_len == 0
      n_chunks = total_len // chunk_len

      chunked_shape = (batch_size, n_heads, n_chunks, chunk_len, d_feature)
      v = jnp.reshape(v, chunked_shape)
      v = v.swapaxes(1, 2)
      return jnp.reshape(v,
                         (batch_size * n_chunks, n_heads, chunk_len, d_feature))

    def chunk_join(v, total_len=original_l):
      """Inverse of `chunk_split`: restores the (batch, heads, length) layout."""
      assert total_len % chunk_len == 0
      n_chunks = total_len // chunk_len
      swapped_shape = (batch_size, n_chunks, n_heads, chunk_len, d_feature)
      v = jnp.reshape(v, swapped_shape)
      v = v.swapaxes(1, 2)
      return jnp.reshape(v, (batch_size, n_heads, total_len, d_feature))

    if chunk_offset == 0:
      queries, keys, values = map(chunk_split, [queries, keys, values])
      chunked_dots = _calc_attn_scores(queries, keys)
      chunked_result = jnp.matmul(chunked_dots, values)
      out = chunk_join(chunked_result)
    else:
      assert chunk_len > chunk_offset
      last_chunk_len = chunk_len - chunk_offset

      def split_along_l(v, mid_start, mid_end, end):
        """Splits axis 2 into [0, mid_start), [mid_start, mid_end), [mid_end, end)."""
        # Index arrays (rather than Python `range` objects) keep `jnp.take`
        # happy on JAX versions that require array-like indices.
        pre = jnp.take(v, indices=jnp.arange(mid_start), axis=2)
        mid = jnp.take(v, indices=jnp.arange(mid_start, mid_end), axis=2)
        post = jnp.take(v, indices=jnp.arange(mid_end, end), axis=2)
        return pre, mid, post

      def pad_to_chunk_len(v):
        """Zero-pads the end of axis 2 up to `chunk_len`."""
        width = [(0, 0)] * v.ndim
        width[2] = (0, chunk_len - v.shape[2])
        return jnp.pad(v, width, mode='constant', constant_values=0.0)

      def pad_borders(v):
        """Pads the partial first and last chunks to full `chunk_len` chunks."""
        total_len = v.shape[2]
        pre, mid, post = split_along_l(v, chunk_offset,
                                       total_len - last_chunk_len, total_len)
        pre, post = map(pad_to_chunk_len, [pre, post])
        return jnp.concatenate([pre, mid, post], axis=2)

      def unpad_borders(v):
        """Inverse of `pad_borders`: strips the padding from both border chunks."""
        padded_total_len = v.shape[2]
        assert padded_total_len == original_l + chunk_len
        pre_padded, mid, post_padded = split_along_l(
            v, chunk_len, padded_total_len - chunk_len, padded_total_len)
        pre = jnp.take(pre_padded, indices=jnp.arange(chunk_offset), axis=2)
        post = jnp.take(
            post_padded, indices=jnp.arange(last_chunk_len), axis=2)
        return jnp.concatenate([pre, mid, post], axis=2)

      queries, keys, values = map(lambda x: chunk_split(pad_borders(x)),
                                  [queries, keys, values])
      permuted_dots = _calc_attn_scores(queries, keys)
      # Padding adds exactly one extra chunk of length in total.
      permuted_out = chunk_join(
          jnp.matmul(permuted_dots, values), total_len=original_l + chunk_len)

      out = unpad_borders(permuted_out)

  out = out.astype(jnp.float32)
  return out, None  # We don't store full dots matrix