def DotProductAttention(query, key, value, mask): """Dot product self-attention. Args: query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d) key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d) value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k) Returns: jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by L_k) """ assert query.shape[-1] == key.shape[-1] == value.shape[ -1], "Embedding dimensions of q, k, v aren't all the same" depth = query.shape[-1] # Calculate scaled query key dot product according to formula above dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth) if mask is not None: # The 'None' in this line does not need to be replaced dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) # Softmax formula implementation logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True) dots = jnp.exp(dots - logsumexp) attention = jnp.matmul(dots, value) return attention
def _per_head_attention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new per-head activations via scaled dot-product attention.

  This function is the core of the attention mechanism. Given per-head
  ``queries`` (Q), ``keys`` (K), ``values`` (V), and ``mask``, it:

    - computes the scaled dot product of each Q-K pair;
    - applies ``mask`` to screen out positions that come from padding tokens
      (indicated by 0 value);
    - [in ``'train'`` mode] applies dropout to Q-K dot products;
    - computes Q-K attention strengths using a per-query softmax of the Q-K
      dot products; and
    - for each query position, combines V vectors according to the Q-K
      attention strengths.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention
        strengths.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored. Applies only in
        ``'train'`` mode.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Tuple of (activations, attn_strengths), where activations are new per-head
    activation vectors and attn_strengths is a matrix of per-head attention
    strengths.
  """
  if dropout >= 1.0:
    raise ValueError(f'Dropout rate ({dropout}) must be lower than 1.')

  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  attn_strengths = (
      jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True)))
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, attn_strengths.shape)
    attn_strengths = jnp.where(keep,
                               attn_strengths / (1.0 - dropout),
                               jnp.zeros_like(attn_strengths))
  activations = jnp.matmul(attn_strengths, values).astype(jnp.float32)
  attn_strengths = attn_strengths.astype(jnp.float32)
  return activations, attn_strengths
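# Hedged aside (not library code): the division by (1.0 - dropout) above is
# the standard inverted-dropout rescaling, so the expected attention strength
# is unchanged by dropout. A tiny numeric check of that identity:
rate = 0.1
rng_demo = fastmath.random.get_prng(0)   # demo-only PRNG key
ones = jnp.ones((100000,))
keep_demo = fastmath.random.bernoulli(rng_demo, 1.0 - rate, ones.shape)
ones_dropped = jnp.where(keep_demo, ones / (1.0 - rate), jnp.zeros_like(ones))
print(float(ones_dropped.mean()))  # close to 1.0: expectation is preserved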
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    # TODO(kitaev): workaround for https://github.com/google/jax/issues/850
    # We must ensure that both mask and the -1e9 constant have a data
    # dependency on the input. Broadcasted copies of these use a lot of
    # memory, so they should be computed at runtime (rather than being global
    # constants).
    if fastmath.is_backend(fastmath.Backend.JAX):
      mask = jax.lax.tie_in(dots, mask)
    # JAX's `full_like` already ties in -1e9 to dots.
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, separate_cls, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    pos_emb: Per-head activations representing positional embeddings.
    context_bias: Global context bias from Transformer XL's attention.
    location_bias: Global location bias from Transformer XL's attention.
    mask: Mask that distinguishes positions with real content vs. padding.
    separate_cls: If True, treat the leading cls token separately by masking
        out the location-based part of its attention.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
  d_feature = queries.shape[-1]
  keys_len, queries_len = keys.shape[-2], queries.shape[-2]
  funnel_factor, is_upsampling = calc_funnel_ratio(keys_len, queries_len)

  ac = jnp.einsum('bnid,bnjd->bnij', queries + context_bias, keys)
  bd = jnp.einsum('bnid,jnd->bnij', queries + location_bias, pos_emb)

  if mode != 'predict':
    bd = _fast_matrix_shift(bd, funnel_factor, is_upsampling)

  if separate_cls:
    # Masking out location part of attention for cls token.
    bd = bd.at[:, :, :, 0].set(0)
    bd = bd.at[:, :, 0, :].set(0)

  dots = (ac + bd) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
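# Reference note (hedged, not part of the library source): the `ac`/`bd` terms
# above follow Transformer-XL's relative-attention decomposition. Writing
# u = context_bias, v = location_bias, and r_{i-j} for the relative positional
# embedding matched to the query/key offset, the pre-softmax score is roughly
#
#     A[i, j] = ((q_i + u) . k_j  +  (q_i + v) . r_{i-j}) / sqrt(d_feature)
#
# with the funnel-specific `_fast_matrix_shift` and the `separate_cls` masking
# applied to the positional (`bd`) term only.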
def DotProductAttention(queries, keys, values, mask, dropout, mode, rng):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head ``queries`` and
      ``keys``,
    - applies ``mask`` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head ``values`` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for attention dropout, which overrides
        (sets to zero) some attention strengths derived from query-key
        matching. As a result, on a given forward pass, some value vectors
        don't contribute to the output, analogous to how regular dropout can
        cause some node activations to be ignored.
    mode: One of ``'train'``, ``'eval'``, or ``'predict'``.
    rng: Single-use random number generator (JAX PRNG key).

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
  d_feature = queries.shape[-1]
  dots = jnp.matmul(queries, jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d_feature)
  if mask is not None:
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
  # Softmax.
  dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))
  if dropout >= 1.0:
    raise ValueError('Dropout rates must be lower than 1.')
  if dropout is not None and dropout > 0.0 and mode == 'train':
    keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
    dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))
  out = jnp.matmul(dots, values)
  out = out.astype(jnp.float32)
  dots = dots.astype(jnp.float32)
  return out, dots
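# Hedged usage sketch for the dropout-enabled variant above (not library code):
# in 'train' mode a JAX PRNG key must be supplied so the attention-dropout mask
# is reproducible; in 'eval' or 'predict' mode the dropout branch is skipped.
batch, heads, seq_len, d_head = 2, 4, 6, 8
q = k = v = jnp.ones((batch, heads, seq_len, d_head))
pad_mask = jnp.ones((batch, 1, seq_len, seq_len), dtype=bool)  # no padding
rng_demo = fastmath.random.get_prng(0)                         # demo key only

out, attn = DotProductAttention(q, k, v, pad_mask,
                                dropout=0.1, mode='train', rng=rng_demo)
print(out.shape, attn.shape)  # (2, 4, 6, 8) (2, 4, 6, 6)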
def DotProductAttention(query, key, value, mask):
    assert query.shape[-1] == key.shape[-1] == value.shape[-1]

    depth = query.shape[-1]

    # Part of dot product formula
    dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth)

    # Apply mask
    if mask is not None:
        dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))

    # Rest of dot product attention formula
    logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True)
    dots = jnp.exp(dots - logsumexp)

    attention = jnp.matmul(dots, value)

    return attention
def DotProductAttention(query, key, value, mask): """Dot product self-attention. Args: query (jax.interpreters.xla.DeviceArray): array of query representations with shape (L_q by d) key (jax.interpreters.xla.DeviceArray): array of key representations with shape (L_k by d) value (jax.interpreters.xla.DeviceArray): array of value representations with shape (L_k by d) where L_v = L_k mask (jax.interpreters.xla.DeviceArray): attention-mask, gates attention with shape (L_q by L_k) Returns: jax.interpreters.xla.DeviceArray: Self-attention array for q, k, v arrays. (L_q by L_k) """ assert query.shape[-1] == key.shape[-1] == value.shape[ -1], "Embedding dimensions of q, k, v aren't all the same" # scaling down (Q. K) dot product with square root of depth depth = query.shape[-1] # Calculate scaled query key dot product according to formula above dots = jnp.matmul(query, jnp.swapaxes(key, -1, -2)) / jnp.sqrt(depth) # Apply the mask if mask is not None: dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9)) # Softmax formula implementation # Use trax.fastmath.logsumexp of dots to avoid underflow by division by large numbers logsumexp = trax.fastmath.logsumexp(dots, axis=-1, keepdims=True) # Note: softmax = e^(dots - logsumexp(dots)) = E^dots / sumexp(dots) dots = jnp.exp(dots - logsumexp) # Multiply dots by value to get self-attention # Use jnp.matmul() attention = jnp.matmul(dots, value) return attention
def log_gaussian_diag_pdf(x, mu, diag_sigma):  # pylint: disable=invalid-name
  """Returns `log N(x | mu, diag(diag_sigma))`.

  Args:
    x: Point(s) at which to evaluate the log-density.
    mu: Mean vector of the Gaussian.
    diag_sigma: Diagonal of the (diagonal) covariance matrix.
  """
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  b = jnp.sum(jnp.log(diag_sigma), axis=-1)
  # Parenthesize (x - mu) so the difference, not just mu, is divided by the
  # diagonal covariance.
  y = (x - mu) / diag_sigma
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
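# Hedged sanity check (not part of the library): for a diagonal covariance the
# log-density has the elementwise closed form
#     -0.5 * sum_i(log(2*pi*sigma_i) + (x_i - mu_i)**2 / sigma_i),
# which the (corrected) function above should reproduce.
x_demo = jnp.array([0.5, -1.0, 2.0])
mu_demo = jnp.array([0.0, 0.0, 1.0])
diag_sigma_demo = jnp.array([1.0, 2.0, 0.5])   # per-dimension variances

expected = -0.5 * jnp.sum(
    jnp.log(2 * jnp.pi * diag_sigma_demo)
    + (x_demo - mu_demo) ** 2 / diag_sigma_demo)
print(log_gaussian_diag_pdf(x_demo, mu_demo, diag_sigma_demo), expected)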
def log_gaussian_pdf(x, mu, sigma):  # pylint: disable=invalid-name
  """Returns `log N(x | mu, sigma)`.

  Args:
    x: Point(s) at which to evaluate the log-density.
    mu: Mean vector of the Gaussian.
    sigma: Full covariance matrix of the Gaussian.
  """
  a = mu.shape[-1] * jnp.log(2 * jnp.pi)
  _, b = jnp.linalg.slogdet(sigma)
  y = jnp.linalg.solve(sigma, x - mu)
  y = jnp.expand_dims(y, axis=-1)
  xm = jnp.expand_dims(x - mu, axis=-2)
  c = jnp.matmul(xm, y)
  c = jnp.squeeze(jnp.squeeze(c, axis=-1), axis=-1)
  return -0.5 * (a + b + c)
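# Hedged sanity check (not part of the library): with a diagonal covariance
# matrix, the general-covariance form should agree with log_gaussian_diag_pdf.
sigma_demo = jnp.diag(diag_sigma_demo)
print(log_gaussian_pdf(x_demo, mu_demo, sigma_demo),
      log_gaussian_diag_pdf(x_demo, mu_demo, diag_sigma_demo))  # should agree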
def DotProductAttention(queries, keys, values, pos_emb, context_bias,
                        location_bias, mask, dropout, mode, rng, chunk_len,
                        chunk_offset):
  """Computes new activations via masked attention-weighted sum of values.

  This function is the core of the attention mechanism. It:
    - computes per-head attention weights from per-head `queries` and `keys`,
    - applies `mask` to screen out positions that come from padding tokens,
    - optionally applies dropout to attention weights, and
    - uses attention weights to combine per-head `values` vectors.

  Args:
    queries: Per-head activations representing attention queries.
    keys: Per-head activations representing attention keys.
    values: Per-head activations to be combined by computed attention weights.
    pos_emb: Per-head activations representing positional embeddings.
    context_bias: Global context bias from Transformer XL's attention.
    location_bias: Global location bias from Transformer XL's attention.
    mask: Mask that distinguishes positions with real content vs. padding.
    dropout: Probabilistic rate for dropout applied to attention strengths
        (based on query-key pairs) before applying them to values.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
    rng: Single-use random number generator (JAX PRNG key).
    chunk_len (optional): Number of tokens per chunk. Setting this option will
        enable chunked attention.
    chunk_offset (optional): Offset for shifting chunks, for shifted chunked
        attention.

  Returns:
    Per-head activations resulting from masked per-head attention-weighted
    sum of per-head values.
  """
  batch_size, n_heads, original_l, d_feature = queries.shape

  def _calc_attn_scores(q, k):
    ac = jnp.einsum('bnid,bnjd->bnij', q + context_bias, k)
    bd = jnp.einsum('bnid,jnd->bnij', q + location_bias, pos_emb)

    if mode != 'predict':
      bd = _fast_matrix_shift(bd)

    dots = (ac + bd) / jnp.sqrt(d_feature)
    dots = jnp.where(mask, dots, jnp.full_like(dots, -1e9))
    # Softmax.
    dots = jnp.exp(dots - fastmath.logsumexp(dots, axis=-1, keepdims=True))

    if dropout >= 1.0:
      raise ValueError('Dropout rates must be lower than 1.')

    if dropout is not None and dropout > 0.0 and mode == 'train':
      keep = fastmath.random.bernoulli(rng, 1.0 - dropout, dots.shape)
      dots = jnp.where(keep, dots / (1.0 - dropout), jnp.zeros_like(dots))

    return dots

  if chunk_len is None or mode == 'predict':
    full_dots = _calc_attn_scores(queries, keys)
    out = jnp.matmul(full_dots, values)
  else:
    assert original_l % chunk_len == 0 and original_l >= chunk_len

    def chunk_split(v):
      total_len = v.shape[2]
      assert total_len % chunk_len == 0
      n_chunks = total_len // chunk_len

      chunked_shape = (batch_size, n_heads, n_chunks, chunk_len, d_feature)
      v = jnp.reshape(v, chunked_shape)
      v = v.swapaxes(1, 2)
      return jnp.reshape(
          v, (batch_size * n_chunks, n_heads, chunk_len, d_feature))

    def chunk_join(v, total_len=original_l):
      assert total_len % chunk_len == 0
      n_chunks = total_len // chunk_len

      swapped_shape = (batch_size, n_chunks, n_heads, chunk_len, d_feature)
      v = jnp.reshape(v, swapped_shape)
      v = v.swapaxes(1, 2)
      return jnp.reshape(v, (batch_size, n_heads, total_len, d_feature))

    if chunk_offset == 0:
      queries, keys, values = map(chunk_split, [queries, keys, values])

      chunked_dots = _calc_attn_scores(queries, keys)
      chunked_result = jnp.matmul(chunked_dots, values)

      out = chunk_join(chunked_result)
    else:
      assert chunk_len > chunk_offset
      last_chunk_len = chunk_len - chunk_offset

      def split_along_l(v, mid_start, mid_end, end):
        pre = jnp.take(v, indices=range(mid_start), axis=2)
        mid = jnp.take(v, indices=range(mid_start, mid_end), axis=2)
        post = jnp.take(v, indices=range(mid_end, end), axis=2)
        return pre, mid, post

      def pad_to_chunk_len(v):
        width = [(0, 0)] * v.ndim
        width[2] = (0, chunk_len - v.shape[2])
        return jnp.pad(v, width, mode='constant', constant_values=0.0)

      def pad_borders(v):
        total_len = v.shape[2]
        pre, mid, post = split_along_l(v, chunk_offset,
                                       total_len - last_chunk_len, total_len)
        pre, post = map(pad_to_chunk_len, [pre, post])
        return jnp.concatenate([pre, mid, post], axis=2)

      def unpad_borders(v):
        padded_total_len = v.shape[2]
        assert padded_total_len == original_l + chunk_len
        pre_padded, mid, post_padded = split_along_l(
            v, chunk_len, padded_total_len - chunk_len, padded_total_len)
        pre = jnp.take(pre_padded, indices=range(chunk_offset), axis=2)
        post = jnp.take(post_padded, indices=range(last_chunk_len), axis=2)
        return jnp.concatenate([pre, mid, post], axis=2)

      queries, keys, values = map(lambda x: chunk_split(pad_borders(x)),
                                  [queries, keys, values])

      permuted_dots = _calc_attn_scores(queries, keys)
      permuted_out = chunk_join(
          jnp.matmul(permuted_dots, values), total_len=original_l + chunk_len)

      out = unpad_borders(permuted_out)

  out = out.astype(jnp.float32)
  return out, None  # We don't store full dots matrix
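# Hedged standalone sketch (demo-only helper names, not library code) of the
# chunk_split / chunk_join reshapes used above: a (batch, heads, L, d) tensor
# is folded so each chunk of `chunk_len` tokens becomes its own batch entry,
# attention is computed independently per chunk, and the inverse reshape then
# restores the original layout.
batch_demo, heads_demo, L_demo, d_demo, chunk_len_demo = 2, 3, 8, 4, 2
n_chunks_demo = L_demo // chunk_len_demo
x_chunks = jnp.arange(
    batch_demo * heads_demo * L_demo * d_demo, dtype=jnp.float32).reshape(
        batch_demo, heads_demo, L_demo, d_demo)

def demo_chunk_split(v):
  v = v.reshape(batch_demo, heads_demo, n_chunks_demo, chunk_len_demo,
                d_demo).swapaxes(1, 2)
  return v.reshape(batch_demo * n_chunks_demo, heads_demo, chunk_len_demo,
                   d_demo)

def demo_chunk_join(v):
  v = v.reshape(batch_demo, n_chunks_demo, heads_demo, chunk_len_demo,
                d_demo).swapaxes(1, 2)
  return v.reshape(batch_demo, heads_demo, L_demo, d_demo)

# The two reshapes are exact inverses of each other.
assert jnp.array_equal(demo_chunk_join(demo_chunk_split(x_chunks)), x_chunks)
print(demo_chunk_split(x_chunks).shape)  # (8, 3, 2, 4): chunks folded into batch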