def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
  """Transformer encoder-decoder layer.

  The input is a triple (encoder, mask, decoder_input) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
  # Decoder self-attending to decoder.
  self_attention = layers.Residual(
      layers.LayerNorm(),
      layers.Branch(),
      layers.Parallel(layers.Identity(),  # activation for (q, k, v)
                      layers.CausalMask(axis=-2)),  # attention mask
      layers.MultiHeadedAttention(
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      layers.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = layers.Serial(
      layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
      layers.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new v
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      layers.Dropout(rate=dropout, mode=mode),
  )
  return layers.Serial(
      layers.Parallel(layers.Identity(), layers.Identity(), self_attention),
      layers.Branch(),
      layers.Parallel(layers.Identity(), encoder_decoder_attention),
      layers.UnnestBranches(),  # (encoder, mask, old_act, new_act)
      layers.Reorder(output=(0, 1, (2, 3))),
      layers.Parallel(  # Residual after encoder-decoder attention.
          layers.Identity(), layers.Identity(), layers.SumBranches()),
      layers.Parallel(  # Feed-forward on the third component (decoder).
          layers.Identity(), layers.Identity(),
          ResidualFeedForward(
              feature_depth, feedforward_depth, dropout, mode=mode)))
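

# Hedged sketch (not part of the original model code): builds a single
# encoder-decoder layer in isolation, e.g. to check that it consumes and
# produces the (encoder, mask, decoder_activations) triple described above.
# The helper name and hyperparameters are illustrative assumptions.
def _ExampleEncoderDecoderLayer():
  return EncoderDecoderLayer(feature_depth=128, feedforward_depth=512,
                             num_heads=4, dropout=0.1, mode='eval')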


def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects a pair (source, target) as input.

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  embedding = layers.Serial(
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len))
  encoder = layers.Serial(
      layers.Branch(),  # Branch input to create embedding and mask.
      layers.Parallel(embedding, layers.PaddingMask()),
      layers.Serial(*[
          EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                       mode)
          for _ in range(num_layers)
      ]),
      layers.Parallel(layers.LayerNorm(), layers.Identity()))
  stack = [
      EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                          mode)
      for _ in range(num_layers)
  ]
  return layers.Serial(
      layers.Parallel(layers.Identity(), layers.ShiftRight()),
      layers.Parallel(encoder, embedding),
      layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
      layers.Reorder(output=(0, (1, 2), 2)),
      layers.Parallel(  # (encoder_mask, decoder_input) -> encoder-decoder mask
          layers.Identity(), layers.EncoderDecoderMask(), layers.Identity()),
      layers.Serial(*stack),
      layers.ThirdBranch(),
      layers.LayerNorm(),
      layers.Dense(vocab_size),
      layers.LogSoftmax())
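

# Hedged usage sketch (not from the original file): constructing the model with
# small, illustrative hyperparameters. This assumes the `layers` combinator
# module used above is importable here; parameter initialization and training
# wiring depend on the surrounding framework and are not shown.
if __name__ == '__main__':
  toy_model = Transformer(vocab_size=1000, feature_depth=128,
                          feedforward_depth=512, num_layers=2, num_heads=4,
                          dropout=0.1, max_len=256, mode='eval')
  print(toy_model)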