def local_dot_product_attention(query,
                                key,
                                value,
                                dtype=jnp.float32,
                                bias=None,
                                axis=None,
                                broadcast_dropout=True,
                                dropout_rng=None,
                                dropout_rate=0.,
                                deterministic=False,
                                precision=None):
  """Computes dot-product attention given query, key, and value.

  Note: this is equivalent to the dot-product attention in `flax.nn`. However,
  we do extra broadcasting of the bias in this function. I'm leaving this here
  in case we need to modify something later.

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762. It calculates the attention weights given
  query and key and combines the values using the attention weights. This
  function supports multi-dimensional inputs.

  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1, dim2,
      ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32).
    bias: bias for the attention weights. This can be used for incorporating an
      autoregressive mask, padding mask, or proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool: if True, dropout is not applied.
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[batch_size, dim1, dim2, ..., dimN, num_heads,
    value_channels]`.
  """
  assert key.shape[:-1] == value.shape[:-1]
  assert (query.shape[0:1] == key.shape[0:1] and
          query.shape[-1] == key.shape[-1])

  if axis is None:
    axis = tuple(range(1, key.ndim - 2))
  if not isinstance(axis, Iterable):
    axis = (axis,)
  assert key.ndim == query.ndim
  assert key.ndim == value.ndim
  for ax in axis:
    if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
      raise ValueError('Attention axis must be between the batch '
                       'axis and the last-two axes.')
  depth = query.shape[-1]
  n = key.ndim
  # batch_dims is <bs, <non-attention dims>, num_heads>
  batch_dims = tuple(onp.delete(range(n), axis + (n - 1,)))
  # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
  qk_perm = batch_dims + axis + (n - 1,)
  key = key.transpose(qk_perm)
  query = query.transpose(qk_perm)
  # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
  v_perm = batch_dims + (n - 1,) + axis
  value = value.transpose(v_perm)

  query = query / jnp.sqrt(depth).astype(dtype)
  batch_dims_t = tuple(range(len(batch_dims)))
  attn_weights = lax.dot_general(
      query, key, (((n - 1,), (n - 1,)), (batch_dims_t, batch_dims_t)),
      precision=precision)

  # Apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    bias = bias[:, :, None, :, :]
    attn_weights = attn_weights + bias

  # Normalize the attention weights.
  norm_dims = tuple(range(attn_weights.ndim - len(axis), attn_weights.ndim))
  attn_weights = jax.nn.softmax(attn_weights, axis=norm_dims)
  attn_weights = attn_weights.astype(dtype)

  # Apply dropout.
  if not deterministic and dropout_rate > 0.:
    if dropout_rng is None:
      dropout_rng = make_rng()
    keep_prob = jax.lax.tie_in(attn_weights, 1.0 - dropout_rate)
    if broadcast_dropout:
      # Dropout is broadcast across the batch + head + non-attention dimensions.
      dropout_dims = attn_weights.shape[-(2 * len(axis)):]
      dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape)
    multiplier = (keep.astype(attn_weights.dtype) /
                  jnp.asarray(keep_prob, dtype=dtype))
    attn_weights = attn_weights * multiplier

  # Compute the new values given the attention weights.
  wv_contracting_dims = (norm_dims,
                         tuple(range(value.ndim - len(axis), value.ndim)))
  y = lax.dot_general(
      attn_weights, value,
      (wv_contracting_dims, (batch_dims_t, batch_dims_t)),
      precision=precision)

  # Back to (bs, dim1, dim2, ..., dimN, num_heads, channels).
  perm_inv = _invert_perm(qk_perm)
  y = y.transpose(perm_inv)
  return y
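
# A minimal usage sketch added for illustration; it is not part of the original
# module. It exercises `local_dot_product_attention` on self-attention shaped
# inputs `[batch_size, length, num_heads, channels]`. The shapes and the PRNG
# seed below are arbitrary assumptions, and `deterministic=True` is used so
# that no dropout RNG is needed.
def _example_local_dot_product_attention():
  rng = random.PRNGKey(0)
  q_rng, k_rng, v_rng = random.split(rng, 3)
  # [batch_size=2, length=16, num_heads=4, channels=32]
  query = random.normal(q_rng, (2, 16, 4, 32), dtype=jnp.float32)
  key = random.normal(k_rng, (2, 16, 4, 32), dtype=jnp.float32)
  value = random.normal(v_rng, (2, 16, 4, 32), dtype=jnp.float32)
  # With axis=None the attention axis defaults to the single `length` dim.
  out = local_dot_product_attention(query, key, value, deterministic=True)
  assert out.shape == (2, 16, 4, 32)
  return out
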
def dot_product_attention_Modified(query,
                                   key,
                                   value,
                                   dtype=jnp.float32,
                                   bias=None,
                                   axis=None,
                                   broadcast_dropout=True,
                                   dropout_rng=None,
                                   dropout_rate=0.,
                                   deterministic=False,
                                   precision=None):
  """Computes a modified dot-product attention given query, key, and value.

  DEPRECATION WARNING: "The `flax.nn` module is deprecated, use `flax.linen`
  instead. Learn more and find an upgrade guide at
  https://github.com/google/flax/blob/master/flax/linen/README.md"

  This is the core function for applying attention based on
  https://arxiv.org/abs/1706.03762, with a modified order of operations: the
  softmaxed keys are first contracted with the values, and the queries are then
  contracted with that partial result. This function supports multi-dimensional
  inputs.

  Args:
    query: queries for calculating attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, mem_channels]`.
    key: keys for calculating attention with shape of `[batch_size, dim1, dim2,
      ..., dimN, num_heads, mem_channels]`.
    value: values to be used in attention with shape of `[batch_size, dim1,
      dim2, ..., dimN, num_heads, value_channels]`.
    dtype: the dtype of the computation (default: float32).
    bias: bias for the attention weights. This can be used for incorporating an
      autoregressive mask, padding mask, or proximity bias.
    axis: axes over which the attention is applied.
    broadcast_dropout: bool: use a broadcasted dropout along batch dims.
    dropout_rng: JAX PRNGKey to be used for dropout.
    dropout_rate: dropout rate.
    deterministic: bool: if True, dropout is not applied.
    precision: numerical precision of the computation; see `jax.lax.Precision`
      for details.

  Returns:
    Output of shape `[batch_size, dim1, dim2, ..., dimN, num_heads,
    value_channels]`.
  """
  assert key.shape[:-1] == value.shape[:-1]
  assert (query.shape[0:1] == key.shape[0:1] and
          query.shape[-1] == key.shape[-1])

  if axis is None:
    axis = tuple(range(1, key.ndim - 2))
  if not isinstance(axis, Iterable):
    axis = (axis,)
  assert key.ndim == query.ndim
  assert key.ndim == value.ndim
  for ax in axis:
    if not (query.ndim >= 3 and 1 <= ax < query.ndim - 2):
      raise ValueError('Attention axis must be between the batch '
                       'axis and the last-two axes.')
  depth = query.shape[-1]
  n = key.ndim
  # batch_dims is <bs, <non-attention dims>, num_heads>
  batch_dims = tuple(np.delete(range(n), axis + (n - 1,)))
  # q & k -> (bs, <non-attention dims>, num_heads, <attention dims>, channels)
  qk_perm = batch_dims + axis + (n - 1,)
  key = key.transpose(qk_perm)
  query = query.transpose(qk_perm)
  # v -> (bs, <non-attention dims>, num_heads, channels, <attention dims>)
  v_perm = batch_dims + (n - 1,) + axis
  value = value.transpose(v_perm)

  query = query / jnp.sqrt(depth).astype(dtype)
  batch_dims_t = tuple(range(len(batch_dims)))

  # Softmax on the key.
  key_dims = tuple(range(key.ndim - len(axis), key.ndim))
  key_soft = jax.nn.softmax(key, axis=key_dims)
  key_soft = key_soft.astype(dtype)

  # Carry out the dot product between softmax(key)^T and value, contracting
  # over the attention axes of both operands.
  part_results = lax.dot_general(
      key_soft, value,
      ((tuple(range(len(batch_dims), n - 1)), tuple(range(n - len(axis), n))),
       (batch_dims_t, batch_dims_t)),
      precision=precision)

  # Apply attention bias: masking, dropout, proximity bias, etc.
  if bias is not None:
    part_results = part_results + bias

  # Apply dropout.
  if not deterministic and dropout_rate > 0.:
    if dropout_rng is None:
      dropout_rng = make_rng()
    keep_prob = jax.lax.tie_in(part_results, 1.0 - dropout_rate)
    if broadcast_dropout:
      # Dropout is broadcast across the batch + head + non-attention dimensions.
      dropout_dims = part_results.shape[-(2 * len(axis)):]
      dropout_shape = (tuple([1] * len(batch_dims_t)) + dropout_dims)
      keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape)
    else:
      keep = random.bernoulli(dropout_rng, keep_prob, part_results.shape)
    multiplier = (keep.astype(part_results.dtype) /
                  jnp.asarray(keep_prob, dtype=dtype))
    part_results = part_results * multiplier

  # Carry out the dot product between query and part_results, contracting the
  # query channels with the matching key-channel axis of part_results.
  results = lax.dot_general(
      query, part_results,
      (((n - 1,), (part_results.ndim - 2,)), (batch_dims_t, batch_dims_t)),
      precision=precision)

  # Normalize the results (assumed intent: divide by the square root of the
  # total size of the normalization dimensions).
  norm_dims = tuple(range(results.ndim - len(axis), results.ndim))
  norm_size = np.prod([results.shape[d] for d in norm_dims])
  results = results / jnp.sqrt(norm_size).astype(dtype)

  # Back to (bs, dim1, dim2, ..., dimN, num_heads, channels).
  perm_inv = _invert_perm(qk_perm)
  results = results.transpose(perm_inv)
  return results
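
# A minimal usage sketch added for illustration; it is not part of the original
# module. It runs `dot_product_attention_Modified` on the same illustrative
# shapes as the sketch above and checks that the output shape matches the
# standard attention output. The shapes and PRNG seed are arbitrary
# assumptions; dropout is disabled via `deterministic=True`.
def _example_dot_product_attention_Modified():
  rng = random.PRNGKey(1)
  q_rng, k_rng, v_rng = random.split(rng, 3)
  # [batch_size=2, length=16, num_heads=4, channels=32]
  query = random.normal(q_rng, (2, 16, 4, 32), dtype=jnp.float32)
  key = random.normal(k_rng, (2, 16, 4, 32), dtype=jnp.float32)
  value = random.normal(v_rng, (2, 16, 4, 32), dtype=jnp.float32)
  out = dot_product_attention_Modified(query, key, value, deterministic=True)
  assert out.shape == (2, 16, 4, 32)
  return out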