def MultiplicativeConvCausalAttention(d_feature,
                                      n_heads=1,
                                      sparsity=None,
                                      length_kernel_size=3,
                                      dropout=0.0,
                                      max_inference_length=2048,
                                      mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it combines a
  MultiplicativeSparseDense layer with a LocallyConvDense layer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    length_kernel_size: Size of convolution kernel on the length dimension.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  d_module = d_feature // sparsity
  return tl.Serial(
      tl.Select([0, 0]),  # duplicate activations
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),  # shared q, k
      tl.Select([0, 0, 0]),  # use for q, k, v
      tl.Parallel(
          [LocallyConvDense(sparsity, d_module, kernel_size=3,
                            length_kernel_size=length_kernel_size),
           tl.SplitIntoHeads(n_heads)],
          [LocallyConvDense(sparsity, d_module, kernel_size=3,
                            length_kernel_size=length_kernel_size),
           tl.SplitIntoHeads(n_heads)],
          [tl.Concatenate(),  # use permuted and original for v
           LocallyConvDense(sparsity, d_module, kernel_size=1,
                            length_kernel_size=length_kernel_size),
           tl.SplitIntoHeads(n_heads)],
      ),
      tl.DotProductCausalAttention(dropout=dropout,
                                   max_inference_length=max_inference_length,
                                   mode=mode),
      tl.MergeHeads(n_heads),
  )
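

# Usage sketch (an illustrative addition, not part of the original module): builds
# the layer on small dummy activations and runs one forward pass. The helper name
# `_example_multiplicative_conv_causal_attention`, the shapes, and the chosen
# hyperparameters are assumptions for demonstration; `numpy` and `trax.shapes` are
# assumed available, and `d_feature` must be divisible by `sparsity`.
def _example_multiplicative_conv_causal_attention():
  """Runs MultiplicativeConvCausalAttention once on random activations."""
  import numpy as np
  from trax import shapes

  layer = MultiplicativeConvCausalAttention(
      d_feature=8, n_heads=2, sparsity=2, mode='eval')
  x = np.random.uniform(size=(2, 16, 8)).astype(np.float32)  # (batch, length, d_feature)
  layer.init(shapes.signature(x))
  y = layer(x)  # Causally-masked self-attention output, same shape as `x`.
  return y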


def Favor(d_feature, n_heads=1, dropout=0.0, numerical_stabilizer=0.001,
          mode='train'):
  """Returns a layer that maps (activations, mask) to (new_activations, mask).

  See the FAVOR paper for details: https://arxiv.org/abs/2006.03555

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    numerical_stabilizer: float, small number used for numerical stability.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  del dropout, mode  # not implemented yet but needed in the API

  def bidirectional_numerator(query_prime, key_prime, value):
    kvs = jnp.einsum('lbm,lbd->bmd', key_prime, value)
    return jnp.einsum('lbm,bmd->lbd', query_prime, kvs)

  def bidirectional_denominator(query_prime, key_prime):
    all_ones = jnp.ones([query_prime.shape[0]])
    ks_sum = jnp.einsum('lbm,l->bm', key_prime, all_ones)
    return jnp.einsum('lbm,bm->lb', query_prime, ks_sum)

  def relu(x):
    return jnp.where(x <= 0, jnp.zeros_like(x), x)

  def favor(query, key, value, mask):
    query_prime = relu(query) + numerical_stabilizer
    key_prime = relu(key) + numerical_stabilizer
    mask_batch_1_length = jnp.reshape(
        mask, [key.shape[0] // n_heads, 1, key.shape[1]]).astype(jnp.float32)
    mask_heads = mask_batch_1_length + jnp.zeros((1, n_heads, 1))
    key_prime *= jnp.reshape(mask_heads, [key.shape[0], key.shape[1], 1])

    w = bidirectional_numerator(jnp.moveaxis(query_prime, 1, 0),
                                jnp.moveaxis(key_prime, 1, 0),
                                jnp.moveaxis(value, 1, 0))
    r = bidirectional_denominator(jnp.moveaxis(query_prime, 1, 0),
                                  jnp.moveaxis(key_prime, 1, 0))
    w = jnp.moveaxis(w, 0, 1)
    r = jnp.moveaxis(r, 0, 1)
    r = jnp.reciprocal(r)
    r = jnp.expand_dims(r, len(r.shape))
    renormalized_attention = w * r
    return renormalized_attention, mask

  return tl.Serial(
      tl.Branch(
          [tl.Dense(d_feature), tl.SplitIntoHeads(n_heads)],
          [tl.Dense(d_feature), tl.SplitIntoHeads(n_heads)],
          [tl.Dense(d_feature), tl.SplitIntoHeads(n_heads)],
      ),
      tl.Fn('FAVOR', favor, n_out=2),
      tl.MergeHeads(n_heads),  # restore (batch, length, d_feature) before the output Dense
      tl.Dense(d_feature),
  )
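

# Usage sketch (an illustrative addition, not part of the original module): runs the
# bidirectional FAVOR layer on random activations with an all-ones padding mask
# (1 = real token, 0 = padding). The helper name `_example_favor`, the shapes, and
# the mask dtype are assumptions for demonstration; `numpy` and `trax.shapes` are
# assumed available.
def _example_favor():
  """Runs Favor once on (activations, mask), returning (new_activations, mask)."""
  import numpy as np
  from trax import shapes

  layer = Favor(d_feature=8, n_heads=2)
  x = np.random.uniform(size=(2, 16, 8)).astype(np.float32)  # (batch, length, d_feature)
  mask = np.ones((2, 16), dtype=np.int32)  # padding mask over (batch, length)
  layer.init(shapes.signature((x, mask)))
  new_x, out_mask = layer((x, mask))  # new_x has the same shape as `x`; mask passes through.
  return new_x, out_mask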