def AttentionPosition(vec, pos,
                      positions=None, d_model=None, n_heads=8,
                      dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention."""
  # Learn n_heads combinations of the positions, one per attention head.
  new_posns = list(LearnedPosOperations(positions=positions,
                                        n_combinations=n_heads) @ (vec, pos))

  # Queries: dense projection of activations plus the per-head learned
  # positions (not tiled).
  hq = tl.Serial(tl.Dense(d_model),
                 CopyPosToHeads(n_heads, tile=False)) @ ([vec] + new_posns)
  # Keys: dense projection plus the original positions, tiled across heads.
  hk = tl.Serial(tl.Dense(d_model),
                 CopyPosToHeads(n_heads, tile=True)) @ (vec, pos)
  # Values: standard attention-head projection of the activations alone.
  hv = tl.ComputeAttentionHeads(n_heads=n_heads,
                                d_head=d_model // n_heads) @ vec

  # Causal attention, then re-combine heads and positions.
  x, pos = tl.Serial(
      tl.DotProductCausalAttention(dropout=dropout, mode=mode),
      CombineHeadsPos(n_heads=n_heads),
      tl.Dense(d_model)
  ) @ (hq, hk, hv)

  return x, pos

def MultiplicativeModularCausalAttention(d_feature,
                                         n_heads=1,
                                         sparsity=None,
                                         dropout=0.0,
                                         max_inference_length=2048,
                                         mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it combines a
  MultiplicativeSparseDense layer with a LocallyConnectedLayer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  return tl.ConfigurableAttention(
      MultiplicativeModularSparseDense(sparsity, d_feature),
      MultiplicativeModularSparseDense(sparsity, d_feature),
      MultiplicativeModularSparseDense(sparsity, d_feature),
      MultiplicativeModularSparseDense(sparsity, d_feature),
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout,
          max_inference_length=max_inference_length,
          mode=mode))

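
# A minimal sketch of the `tl.ConfigurableAttention` pattern used by the
# constructors in this file: swapping the four sparse projection layers for
# plain `tl.Dense` layers gives a standard dense causal-attention baseline for
# comparison. This helper is illustrative only (an assumption, not part of the
# module's API).
def _DenseBaselineCausalAttention(d_feature, n_heads=1, dropout=0.0,
                                  max_inference_length=2048, mode='train'):
  """Illustrative dense counterpart of the sparse constructors in this file."""
  return tl.ConfigurableAttention(
      tl.Dense(d_feature),  # query projection
      tl.Dense(d_feature),  # key projection
      tl.Dense(d_feature),  # value projection
      tl.Dense(d_feature),  # output projection
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout,
          max_inference_length=max_inference_length,
          mode=mode))
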
def MultiplicativeConvCausalAttention(d_feature,
                                      n_heads=1,
                                      sparsity=None,
                                      length_kernel_size=3,
                                      dropout=0.0,
                                      max_inference_length=2048,
                                      mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it combines a
  MultiplicativeSparseDense layer with a LocallyConvLayer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    length_kernel_size: Size of convolution kernel on the length dimension.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  d_module = d_feature // sparsity
  return tl.Serial(
      tl.Select([0, 0]),  # duplicate activations
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),  # shared q, k
      tl.Select([0, 0, 0]),  # use for q, k, v
      tl.Parallel(
          [LocallyConvDense(sparsity, d_module, kernel_size=3,
                            length_kernel_size=length_kernel_size),
           tl.SplitIntoHeads(n_heads)],
          [LocallyConvDense(sparsity, d_module, kernel_size=3,
                            length_kernel_size=length_kernel_size),
           tl.SplitIntoHeads(n_heads)],
          [tl.Concatenate(),  # use permuted and original for v
           LocallyConvDense(sparsity, d_module, kernel_size=1,
                            length_kernel_size=length_kernel_size),
           tl.SplitIntoHeads(n_heads)],
      ),
      tl.DotProductCausalAttention(dropout=dropout,
                                   max_inference_length=max_inference_length,
                                   mode=mode),
      tl.MergeHeads(n_heads),
  )

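
# Illustrative sketch (not part of this module): a sparse constructor such as
# `MultiplicativeConvCausalAttention` can stand in for `tl.CausalAttention`
# inside the usual residual decoder block. The block structure and the choice
# of `tl.LayerNorm`/`tl.Residual`/`tl.Dense`/`tl.Relu` below are assumptions
# made for illustration, not a prescribed configuration.
def _ExampleDecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                         mode='train'):
  """Residual decoder block built around the sparse attention layer above."""
  attention = MultiplicativeConvCausalAttention(
      d_model, n_heads=n_heads, sparsity=n_heads, dropout=dropout, mode=mode)
  feed_forward = [tl.Dense(d_ff), tl.Relu(), tl.Dense(d_model)]
  return [
      tl.Residual(tl.LayerNorm(), attention),
      tl.Residual(tl.LayerNorm(), feed_forward),
  ]
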
def LowRankCausalAttention(d_feature,
                           n_heads=1,
                           dropout=0.0,
                           max_inference_length=2048,
                           lowrank=64,
                           mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, it uses a low-rank approximation of the kernel in the Dense layer
  for computing Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    lowrank: The rank of the low-rank approximation.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  return tl.ConfigurableAttention(
      LowRankDense(d_feature, lowrank),
      LowRankDense(d_feature, lowrank),
      LowRankDense(d_feature, lowrank),
      LowRankDense(d_feature, lowrank),
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout,
          max_inference_length=max_inference_length,
          mode=mode))

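
# A rough parameter count (ignoring biases) to motivate the low-rank variant:
# a full Dense projection costs d_feature * d_feature weights, while a rank-r
# factorization costs d_feature * r + r * d_feature. The helper below is an
# illustrative sketch under that assumption, not a measurement of the actual
# LowRankDense layer.
def _low_rank_projection_params(d_feature, lowrank):
  """Approximate weight counts for full vs. low-rank Q/K/V projections."""
  full = d_feature * d_feature        # standard Dense kernel
  factored = 2 * d_feature * lowrank  # d x r and r x d factors
  return full, factored

# For example, d_feature=1024 and lowrank=64 gives roughly 1,048,576 vs.
# 131,072 weights per projection, about an 8x reduction.
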
def MultiplicativeCausalAttention(d_feature,
                                  n_heads=1,
                                  sparsity=None,
                                  dropout=0.0,
                                  max_inference_length=2048,
                                  mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it multiplies each
  embedding dimension by a scalar specific to each dimension and each head;
  then it produces Q/K/V by applying the same dense layer to each head. In
  comparison to a standard dense layer for computing Q/K/V, this layer uses
  fewer parameters while still being able to express many functions, like a
  permutation.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  sparsity = n_heads if sparsity is None else sparsity
  return tl.ConfigurableAttention(
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      MultiplicativeSparseDense(sparsity, d_feature, d_feature),
      n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout,
          max_inference_length=max_inference_length,
          mode=mode))

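
# A minimal usage sketch, assuming the standard Trax layer-calling convention
# (`layer.init` on an input signature, then a direct call). The concrete
# dimensions are arbitrary; the layer maps (batch, length, d_feature) to the
# same shape.
def _example_multiplicative_causal_attention():
  """Builds the layer on dummy data and checks that shapes are preserved."""
  import numpy as np
  from trax import shapes

  layer = MultiplicativeCausalAttention(d_feature=64, n_heads=4, mode='eval')
  x = np.ones((2, 10, 64), dtype=np.float32)  # (batch, length, d_feature)
  layer.init(shapes.signature(x))
  y = layer(x)
  assert y.shape == x.shape
  return y
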
def ConvCausalAttention(d_feature,
                        n_heads=1,
                        sparsity=None,
                        dropout=0.0,
                        max_inference_length=2048,
                        kernel_size=1,
                        mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, it uses LocallyConvDense instead of a Dense layer for computing
  Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    sparsity: Number of modules used in LocallyConvDense.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    kernel_size: Kernel size used in LocallyConvDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  n_modules = n_heads if sparsity is None else sparsity

  @assert_shape('...a->...b')
  def ProcessingLayer():
    assert d_feature % n_modules == 0
    return LocallyConvDense(n_modules, d_feature // n_modules,
                            kernel_size=kernel_size)

  return tl.ConfigurableAttention(
      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
      ProcessingLayer(), n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout,
          max_inference_length=max_inference_length,
          mode=mode))

def MultiplicativeModularCausalAttention(  # pylint: disable=invalid-name
    d_feature, sparsity=1, n_heads=1, dropout=0.0,
    max_inference_length=2048, kernel_size=1, mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, for computing Q/K/V, instead of a Dense layer it combines a
  MultiplicativeSparseDense layer with a LocallyConnectedLayer.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    sparsity: The sparsity of the layer; usually it should be equal to n_heads.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    kernel_size: Kernel size used in LocallyConnectedDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  assert d_feature % sparsity == 0

  @assert_shape('...a->...a')
  def ProcessingLayer():  # pylint: disable=invalid-name
    return tl.Serial(
        MultiplicativeSparseDense(sparsity, d_feature, d_feature),
        LocallyConnectedDense(sparsity, d_feature // sparsity,
                              kernel_size=kernel_size))

  return tl.ConfigurableAttention(
      ProcessingLayer(), ProcessingLayer(), ProcessingLayer(),
      ProcessingLayer(), n_heads=n_heads,
      qkv_attention_layer=tl.DotProductCausalAttention(
          dropout=dropout,
          max_inference_length=max_inference_length,
          mode=mode))

def ModularCausalAttention(d_feature, n_heads=1, dropout=0.0,  # pylint: disable=invalid-name
                           max_inference_length=2048, n_modules=1,
                           mode='train'):
  """Returns a layer that maps activations to activations, with causal masking.

  Like `CausalAttention`, this layer type represents one pass of multi-head
  self-attention with causal masking rather than padding-based masking.
  However, it uses LocallyConnectedDense instead of a Dense layer for
  computing Q/K/V.

  Args:
    d_feature: Depth/dimensionality of feature embedding.
    n_heads: Number of attention heads.
    dropout: Probabilistic rate for internal dropout applied to attention
      activations (based on query-key pairs) before dotting them with values.
    max_inference_length: Maximum length for inference.
    n_modules: Number of modules used in LocallyConnectedDense.
    mode: One of `'train'`, `'eval'`, or `'predict'`.
  """
  if d_feature % n_heads != 0:
    raise ValueError(
        f'Dimensionality of feature embedding ({d_feature}) is not a multiple '
        f'of the requested number of attention heads ({n_heads}).')

  d_head = d_feature // n_heads

  @assert_shape('bld->hlx')
  def _split_into_heads():
    """Returns a layer that reshapes tensors for multi-headed computation."""
    def f(x):
      batch_size = x.shape[0]
      seq_len = x.shape[1]

      # (b_size, seq_len, d_feature) --> (b_size*n_heads, seq_len, d_head)
      x = x.reshape((batch_size, seq_len, n_heads, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((batch_size * n_heads, seq_len, d_head))
      return x
    return tl.Fn('SplitIntoHeads', f)

  @assert_shape('hlx->bld')
  def _merge_heads():
    """Returns a layer that undoes splitting, after multi-head computation."""
    def f(x):
      seq_len = x.shape[1]

      # (b_size*n_heads, seq_len, d_head) --> (b_size, seq_len, d_feature)
      x = x.reshape((-1, n_heads, seq_len, d_head))
      x = x.transpose((0, 2, 1, 3))
      x = x.reshape((-1, seq_len, d_head * n_heads))
      return x
    return tl.Fn('MergeHeads', f)

  @assert_shape('...a->...b')
  def ProcessingLayer():  # pylint: disable=invalid-name
    if n_modules == 1:
      return tl.Dense(d_feature)
    else:
      assert d_feature % n_modules == 0
      return LocallyConnectedDense(n_modules, d_feature // n_modules)

  return cb.Serial(
      cb.Branch(
          [ProcessingLayer(), _split_into_heads()],
          [ProcessingLayer(), _split_into_heads()],
          [ProcessingLayer(), _split_into_heads()],
      ),
      tl.DotProductCausalAttention(dropout=dropout,
                                   max_inference_length=max_inference_length,
                                   mode=mode),
      _merge_heads(),
      ProcessingLayer())
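
# The reshapes in `_split_into_heads` / `_merge_heads` above fold the head
# dimension into the batch dimension and back. A small NumPy-only sketch of
# the same index gymnastics (illustrative, independent of the layer machinery):
def _head_reshape_demo():
  """Shows (batch, length, d_feature) <-> (batch*heads, length, d_head)."""
  import numpy as np

  batch, length, n_heads, d_head = 2, 5, 4, 8
  d_feature = n_heads * d_head
  x = np.arange(batch * length * d_feature, dtype=np.float32).reshape(
      (batch, length, d_feature))

  # Split: (b, l, d) -> (b, l, h, d_head) -> (b, h, l, d_head)
  #        -> (b*h, l, d_head)
  split = x.reshape((batch, length, n_heads, d_head)).transpose(
      (0, 2, 1, 3)).reshape((batch * n_heads, length, d_head))

  # Merge: the exact inverse, recovering the original tensor.
  merged = split.reshape((batch, n_heads, length, d_head)).transpose(
      (0, 2, 1, 3)).reshape((batch, length, d_feature))
  assert np.array_equal(merged, x)
  return split.shape  # (batch * n_heads, length, d_head)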