def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  return tl.Serial(
      tl.Residual(  # Attention block here.
          tl.Parallel(tl.LayerNorm(), tl.Copy()),
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())
      ),
      tl.Parallel(
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode),
          tl.Div(divisor=2.0)  # Mask added to itself in the residual, divide.
      )
  )

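# Illustrative sketch only (not part of the model definitions above): because
# each EncoderLayer consumes and returns an (activations, mask) pair, a full
# encoder body is just tl.Serial over several such layers, as Transformer below
# does. The hyperparameter values here are arbitrary examples.
def ExampleEncoderStack(num_layers=6, mode='train'):
  return tl.Serial(*[EncoderLayer(feature_depth=512, feedforward_depth=2048,
                                  num_heads=8, dropout=0.1, mode=mode)
                     for _ in range(num_layers)])
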
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects as input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  embedding = tl.Serial(
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len)
  )
  encoder = tl.Serial(
      tl.Branch(embedding, tl.PaddingMask()),
      tl.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.Parallel(tl.LayerNorm(), tl.Copy())
  )
  stack = [EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
           for _ in range(num_layers)]
  return tl.Serial(
      tl.Parallel(tl.Copy(), tl.ShiftRight()),  # Shift target right for teacher forcing.
      tl.Parallel(encoder, embedding),
      tl.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
      tl.Select((0, (1, 2), 2)),
      tl.Parallel(  # (encoder_mask, decoder_input) -> encoder-decoder mask
          tl.Copy(), tl.EncoderDecoderMask(), tl.Copy()),
      tl.Serial(*stack),
      tl.Select(2),  # Drop encoder and mask.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )

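# Example (sketch): constructing the model defined above. The vocabulary size
# and reduced depth below are hypothetical values chosen for illustration; how
# the returned layer object is then initialized and applied depends on the
# surrounding trax layers API and is not shown here.
def ExampleSmallTransformer(mode='train'):
  return Transformer(vocab_size=32000,     # hypothetical shared vocab size
                     feature_depth=256,
                     feedforward_depth=1024,
                     num_layers=2,
                     num_heads=4,
                     dropout=0.1,
                     max_len=256,
                     mode=mode)
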
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Select(0),  # Drop the mask.
          tl.Dropout(rate=dropout, mode=mode)
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """WideResnet convolutional block."""
  main = tl.Serial(
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
      tl.BatchNorm(),
      tl.Relu(),
      tl.Conv(channels, (3, 3), padding='SAME')
  )
  # Project the shortcut when the channel count or stride changes.
  shortcut = tl.Copy() if not channel_mismatch else tl.Conv(
      channels, (3, 3), strides, padding='SAME')
  return tl.Residual(main, shortcut=shortcut)

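# Example (sketch): the usual WideResnet grouping, where only the first block
# of a group changes the channel count / stride (channel_mismatch=True) and the
# remaining blocks keep the shape. The group size and strides are illustrative,
# not taken from the original code.
def ExampleWideResnetGroup(channels, num_blocks, strides=(1, 1)):
  blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
  blocks += [WideResnetBlock(channels) for _ in range(num_blocks - 1)]
  return tl.Serial(*blocks)
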
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
  """Transformer encoder-decoder layer.

  The input is a triple (encoder, mask, decoder_input) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
  # Decoder self-attending to decoder.
  self_attention = tl.Residual(
      tl.LayerNorm(),
      tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # create mask
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # drop mask
      tl.Dropout(rate=dropout, mode=mode)
  )
  # Decoder attending to encoder.
  encoder_decoder_attention = tl.Serial(
      tl.Select(((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
      tl.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new, mask
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Select(0),  # drop the mask
      tl.Dropout(rate=dropout, mode=mode),
  )
  return tl.Serial(
      tl.Parallel(tl.Copy(), tl.Copy(), self_attention),
      tl.Branch(tl.Copy(), encoder_decoder_attention),
      tl.UnnestBranches(),  # (encoder, mask, old_act, new_act)
      tl.Select((0, 1, (2, 3))),
      tl.Parallel(  # Residual after encoder-decoder attention.
          tl.Copy(), tl.Copy(), tl.Add()),
      tl.Parallel(  # Feed-forward on the third component (decoder).
          tl.Copy(), tl.Copy(),
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode))
  )

def ChunkedCausalMultiHeadedAttention(
    feature_depth, num_heads=8, dropout=0.0, chunk_selector=None,
    mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    feature_depth: int: depth of embedding
    num_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to the list of chunks to
      attend to.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = tl.Serial(
      tl.Branch(
          tl.Branch(  # q = k = v = first input
              tl.Copy(), tl.Copy(), tl.Copy()),
          tl.CausalMask(axis=-2),
      ),
      tl.Parallel(
          tl.Parallel(
              tl.Dense(feature_depth),
              tl.Dense(feature_depth),
              tl.Dense(feature_depth),
          ),
          tl.Copy()
      )
  )
  return tl.Serial(
      tl.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      tl.Map(tl.PureMultiHeadedAttention(
          feature_depth=feature_depth, num_heads=num_heads,
          dropout=dropout, mode=mode), check_shapes=False),
      tl.Map(tl.Select(0), check_shapes=False),  # drop masks
      tl.Map(tl.Dense(feature_depth))
  )

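# Example (sketch): a possible chunk_selector, mapping a chunk index to the
# list of chunk indices it may attend to -- here, the previous chunk and the
# chunk itself. It matches the documented signature above; the particular
# attention policy is only an illustration.
def ExamplePreviousAndCurrentChunkSelector(chunk_number):
  if chunk_number == 0:
    return [0]
  return [chunk_number - 1, chunk_number]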