def MultiplicativeModularCausalAttention(d_feature, n_heads=1, sparsity=None,
                                         dropout=0.0,
                                         max_inference_length=2048,
                                         mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it uses
  MultiplicativeModularSparseDense, which combines a MultiplicativeSparseDense
  layer with a locally connected layer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  return tl.ConfigurableAttention(
      MultiplicativeModularSparseDense(sparsity, d_feature),
      MultiplicativeModularSparseDense(sparsity, d_feature),
      MultiplicativeModularSparseDense(sparsity, d_feature),
      MultiplicativeModularSparseDense(sparsity, d_feature),
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
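# Illustrative sketch (not part of the original module): the layer above maps
# activations to activations just like `tl.CausalAttention`, so it can stand
# in for the attention sub-layer of a decoder block. The block below is a
# made-up minimal example; the helper name and dimensions are arbitrary.
def _example_sparse_decoder_block(d_feature=64, n_heads=4):
  return tl.Serial(
      tl.Residual(
          tl.LayerNorm(),
          MultiplicativeModularCausalAttention(d_feature, n_heads=n_heads),
      ),
      tl.Residual(
          tl.LayerNorm(),
          tl.Dense(4 * d_feature),
          tl.Relu(),
          tl.Dense(d_feature),
      ),
  )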
def LowRankCausalAttention(d_feature, n_heads=1, dropout=0.0,
                           max_inference_length=2048, lowrank=64,
                           mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, it uses a low-rank approximation of the kernel in the Dense layer
  for computing Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    lowrank: The rank of the low-rank approximation.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return tl.ConfigurableAttention(
      LowRankDense(d_feature, lowrank),
      LowRankDense(d_feature, lowrank),
      LowRankDense(d_feature, lowrank),
      LowRankDense(d_feature, lowrank),
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
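# Rough parameter-count comparison for the projections above (an illustrative
# sketch, not part of the original module), assuming LowRankDense factors the
# d_feature x d_feature kernel through a rank-`lowrank` bottleneck of two
# smaller dense maps. Ignoring biases, each Q/K/V projection then costs about
# 2 * d_feature * lowrank parameters instead of d_feature ** 2. The helper
# name and example numbers below are arbitrary.
def _low_rank_param_counts(d_feature=512, lowrank=64):
  full_dense = d_feature * d_feature                      # 262144 for 512
  low_rank = d_feature * lowrank + lowrank * d_feature    # 65536 for 512/64
  return full_dense, low_rank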
def MultiplicativeCausalAttention(d_feature, n_heads=1, sparsity=None,
                                  dropout=0.0, max_inference_length=2048,
                                  mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it multiplies each
  embedding dimension by a scalar specific to each dimension and each head,
  and then produces Q/K/V by applying the same dense layer to each head. In
  comparison to a standard Dense layer for computing Q/K/V, this layer uses
  fewer parameters while still being able to express many functions, like a
  permutation.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  return tl.ConfigurableAttention(
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
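# Illustrative usage sketch (not part of the original module), following the
# usual Trax init-then-apply convention; the helper name, batch size, length
# and feature depth below are arbitrary examples.
def _example_multiplicative_causal_attention():
  import numpy as np
  from trax import shapes
  layer = MultiplicativeCausalAttention(d_feature=64, n_heads=4)  # sparsity=4
  x = np.zeros((2, 16, 64), dtype=np.float32)  # (batch, length, d_feature)
  layer.init(shapes.signature(x))
  return layer(x)  # same (batch, length, d_feature) shape, causally masked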
def ConvCausalAttention(d_feature, n_heads=1, sparsity=None, dropout=0.0,
                        max_inference_length=2048, kernel_size=1,
                        mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, it uses LocallyConvDense instead of a Dense layer for computing
  Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: Number of modules used in LocallyConvDense.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    kernel_size: Kernel size used in LocallyConvDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  n_modules = n_heads if sparsity is None else sparsity

  @assert_shape('...a->...b')
  def ProcessingLayer():  # pylint: disable=invalid-name
    assert d_feature % n_modules == 0
    return LocallyConvDense(n_modules, d_feature // n_modules,
                            kernel_size=kernel_size)

  return tl.ConfigurableAttention(
      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
      ProcessingLayer(), n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
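# Note on the modules above (an illustrative sketch, not part of the original
# module): each of the `sparsity` LocallyConvDense modules produces
# d_feature // sparsity units, so d_feature must be divisible by the number of
# modules. The helper name and numbers below are arbitrary examples.
def _example_conv_causal_attention():
  import numpy as np
  from trax import shapes
  # 64 % 4 == 0, so four modules of 16 units each compute Q/K/V.
  layer = ConvCausalAttention(d_feature=64, n_heads=4, sparsity=4)
  x = np.zeros((2, 16, 64), dtype=np.float32)  # (batch, length, d_feature)
  layer.init(shapes.signature(x))
  return layer(x)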
def Favor(d_feature, n_heads=1, dropout=0.0, numerical_stabilizer=0.001,
          mode='train'):
  """Returns a layer that maps (activations, mask) to (new_activations, mask).

  See the FAVOR paper for details: https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  del dropout, mode  # not implemented yet but needed in the API

  def bidirectional_numerator(query_prime, key_prime, value):
    kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
    return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)

  def bidirectional_denominator(query_prime, key_prime):
    all_ones = jnp.ones([query_prime.shape[0]])
    ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
    return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)

  def relu(x):
    return jnp.where(x <= 0, jnp.zeros_like(x), x)

  def favor(query, key, value, mask):
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    mask_batch_1_length = jnp.reshape(
        mask, [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
    mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
    key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

    w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                                jnp.moveaxis(key_prime, 1, 0),
                                jnp.moveaxis(value, 1, 0))
    r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                  jnp.moveaxis(key_prime, 1, 0))
    w = jnp.moveaxis(w, 0, 1)
    r = jnp.moveaxis(r, 0, 1)
    r = jnp.reciprocal(r)
    r = jnp.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention, mask

  return tl.ConfigurableAttention(
      tl.Dense(d_feature), tl.Dense(d_feature), tl.Dense(d_feature),
      tl.Dense(d_feature),
      qkv_attention_layer=tl.Fn('FAVOR', favor, n_out=2),
      n_heads=n_heads)
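# Illustrative usage sketch (not part of the original module): unlike the
# causal variants above, this layer consumes an (activations, mask) pair and
# returns (new_activations, mask). The mask layout assumed below, shape
# (batch, length) with 1 for real tokens and 0 for padding, and the helper
# name are illustrative assumptions.
def _example_favor():
  import numpy as np
  from trax import shapes
  layer = Favor(d_feature=64, n_heads=4)
  x = np.zeros((2, 16, 64), dtype=np.float32)  # (batch, length, d_feature)
  mask = np.ones((2, 16), dtype=np.float32)    # (batch, length)
  layer.init(shapes.signature((x, mask)))
  new_x, new_mask = layer((x, mask))
  return new_x, new_mask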
def MultiplicativeModularCausalAttention(  # pylint: disable=invalid-name
    d_feature, sparsity=1, n_heads=1, dropout=0.0, max_inference_length=2048,
    kernel_size=1, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it combines a
  MultiplicativeSparseDense layer with a LocallyConnectedDense layer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    kernel_size: Kernel size used in LocallyConnectedDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  assert d_feature % sparsity == 0

  @assert_shape('...a->...a')
  def ProcessingLayer():  # pylint: disable=invalid-name
    return tl.Serial(
        MultiplicativeSparseDense(sparsity, d_feature, d_feature),
        LocallyConnectedDense(sparsity, d_feature // sparsity,
                              kernel_size=kernel_size))

  return tl.ConfigurableAttention(
      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
      ProcessingLayer(), n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout, max_inference_length=max_inference_length,
          mode=mode))
def CausalFavor(d_feature, n_heads=1, dropout=0.0,
                numerical_stabilizer=0.001, precision=None, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  causal attention, but using FAVOR fast attention as in the following paper:
  https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
        activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    precision: passed to jnp.einsum to define arithmetic precision.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  del dropout, mode  # not implemented yet but needed in the API

  def favor_numerator_fwd(init_prefix_sum_value, precision,
                          query_prime, key_prime, value):
    def body(p, qkv):
      (q, k, v) = qkv
      p += jnp.einsum('...m,...d->...md', k, v, precision=precision)
      x_slice = jnp.einsum('...m,...md->...d', q, p, precision=precision)
      return p, x_slice
    p, w = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime, value))
    return w, (precision, p, query_prime, key_prime, value)

  def favor_numerator_bwd(pqkv, w_ct):
    precision, p, qs, ks, vs = pqkv

    def body(carry, qkv_xct):
      p, p_ct = carry
      q, k, v, x_ct = qkv_xct
      q_ct = jnp.einsum('...d,...md->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...d,...m->...md', x_ct, q, precision=precision)
      k_ct = jnp.einsum('...md,...d->...m', p_ct, v, precision=precision)
      v_ct = jnp.einsum('...md,...m->...d', p_ct, k, precision=precision)
      p -= jnp.einsum('...m,...d->...md', k, v, precision=precision)
      return (p, p_ct), (q_ct, k_ct, v_ct)

    _, (qs_ct, ks_ct, vs_ct) = fastmath.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, vs, w_ct), reverse=True)
    return (None, None, qs_ct, ks_ct, vs_ct)

  def favor_numerator(init_prefix_sum_value, precision, query_prime,
                      key_prime, value):
    w, _ = favor_numerator_fwd(init_prefix_sum_value, precision,
                               query_prime, key_prime, value)
    return w

  favor_numerator = fastmath.custom_vjp(favor_numerator, favor_numerator_fwd,
                                        favor_numerator_bwd)

  def favor_denominator_fwd(init_prefix_sum_value, precision,
                            query_prime, key_prime):
    def body(p, qk):
      q, k = qk
      p += k
      x = jnp.einsum('...m,...m->...', q, p, precision=precision)
      return p, x
    p, r = fastmath.scan(body, init_prefix_sum_value,
                         (query_prime, key_prime))
    return r, (precision, query_prime, key_prime, p)

  def favor_denominator_bwd(qkp, r_ct):
    precision, qs, ks, p = qkp

    def body(carry, qkx):
      p, p_ct = carry
      q, k, x_ct = qkx
      q_ct = jnp.einsum('...,...m->...m', x_ct, p, precision=precision)
      p_ct += jnp.einsum('...,...m->...m', x_ct, q, precision=precision)
      k_ct = p_ct
      p -= k
      return (p, p_ct), (q_ct, k_ct)

    _, (qs_ct, ks_ct) = fastmath.scan(
        body, (p, jnp.zeros_like(p)), (qs, ks, r_ct), reverse=True)
    return (None, None, qs_ct, ks_ct)

  def favor_denominator(init_prefix_sum_value, precision, query_prime,
                        key_prime):
    r, _ = favor_denominator_fwd(init_prefix_sum_value, precision,
                                 query_prime, key_prime)
    return r

  favor_denominator = fastmath.custom_vjp(favor_denominator,
                                          favor_denominator_fwd,
                                          favor_denominator_bwd)

  def relu(x):
    return jnp.where(x <= 0, jnp.zeros_like(x), x)

  def favor(query, key, value):
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    prefix_sum_tensor_shape = (key.shape[0], key.shape[-1], value.shape[-1])
    t_slice_shape = (key.shape[0], key.shape[-1])
    init_prefix_sum_value_numerator = jnp.zeros(prefix_sum_tensor_shape)
    init_prefix_sum_value_denominator = jnp.zeros(t_slice_shape)

    w = favor_numerator(init_prefix_sum_value_numerator, precision,
                        jnp.moveaxis(query_prime, 1, 0),
                        jnp.moveaxis(key_prime, 1, 0),
                        jnp.moveaxis(value, 1, 0))
    r = favor_denominator(init_prefix_sum_value_denominator, precision,
                          jnp.moveaxis(query_prime, 1, 0),
                          jnp.moveaxis(key_prime, 1, 0))
    w = jnp.moveaxis(w, 0, 1)
    r = jnp.moveaxis(r, 0, 1)
    r = jnp.reciprocal(r)
    r = jnp.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention

  return tl.ConfigurableAttention(
      core.Dense(d_feature), core.Dense(d_feature), core.Dense(d_feature),
      core.Dense(d_feature), n_heads=n_heads,
      qkv_attention_layer=base.Fn('CausalFAVOR', favor))
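# A minimal NumPy reference (an illustrative sketch, not part of the original
# module) for the causal prefix-sum computation that the custom-VJP scans
# above implement: position i uses numerator q'_i . (sum_{j<=i} k'_j v_j^T)
# and denominator q'_i . (sum_{j<=i} k'_j), so the cost is linear in sequence
# length. The function name and single-head, single-batch shapes are
# illustrative only.
def _causal_favor_reference(q_prime, k_prime, v):
  """q_prime, k_prime: [length, d_k]; v: [length, d_v]."""
  import numpy as np
  length, d_v = v.shape
  out = np.zeros((length, d_v), dtype=v.dtype)
  kv_sum = np.zeros((k_prime.shape[1], d_v), dtype=v.dtype)  # running k'^T v
  k_sum = np.zeros((k_prime.shape[1],), dtype=v.dtype)       # running sum k'
  for i in range(length):
    kv_sum += np.outer(k_prime[i], v[i])
    k_sum += k_prime[i]
    out[i] = (q_prime[i] @ kv_sum) / (q_prime[i] @ k_sum)
  return out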