def Encoder(source, source_mask):
    """Build the Transformer encoder stack.

    NOTE: besides its two parameters, this function reads `multi_attention`,
    `dropout`, `mode`, `feature_depth`, `feedforward_depth`,
    `source_embedding_layer` and `num_layers` from an enclosing scope; they
    are not defined inside this function.

    Args:
      source: layer variable: raw source sequences
      source_mask: layer variable: self-attention mask

    Returns:
      Layer variable that outputs encoded source.
    """
    # Input attends to itself: the normalized activations are branched four
    # ways and routed as (query, key, value, mask) into the attention layer.
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(size=4),
        layers.Parallel(
            layers.Identity(),  # query
            layers.Identity(),  # key
            layers.Identity(),  # value
            source_mask),  # attention mask
        multi_attention,
        layers.Dropout(dropout, mode=mode))

    # Position-wise feed-forward sub-block.
    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    encoder_layer = layers.Serial(self_attention, feed_forward)

    # Embed the source, apply the layer `num_layers` times, then normalize.
    return layers.Serial(
        source,
        source_embedding_layer,
        layers.repeat(encoder_layer, num_layers),
        layers.LayerNorm(),
    )
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Assemble the full encoder-decoder Transformer.

    The model expects a pair (source, target) as input. The same embedding
    layer object is used for the source (inside the encoder) and for the
    right-shifted target.

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # Token embedding followed by dropout and positional encoding.
    embedding = layers.Serial(
        layers.Embedding(feature_depth, vocab_size),
        layers.Dropout(rate=dropout, mode=mode),
        layers.PositionalEncoding(max_len=max_len))

    # Encoder pieces, constructed in the order they are applied.
    split_source = layers.Branch()  # Branch input to create embedding and mask.
    embed_and_mask = layers.Parallel(embedding, layers.PaddingMask())
    encoder_stack = layers.Serial(*(
        EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                     mode)
        for _ in range(num_layers)))
    # Normalize the activations; the mask passes through untouched.
    normalize_activations = layers.Parallel(layers.LayerNorm(),
                                            layers.Identity())
    encoder = layers.Serial(split_source, embed_and_mask, encoder_stack,
                            normalize_activations)

    decoder_stack = [
        EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, mode)
        for _ in range(num_layers)
    ]

    pipeline = [
        # (source, target) -> (source, right-shifted target)
        layers.Parallel(layers.Identity(), layers.ShiftRight()),
        layers.Parallel(encoder, embedding),
        layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
        layers.Reorder(output=(0, (1, 2), 2)),
        layers.Parallel(
            # (encoder_mask, decoder_input) -> encoder-decoder mask
            layers.Identity(),
            layers.EncoderDecoderMask(),
            layers.Identity()),
        layers.Serial(*decoder_stack),
        # Keep only the third element (the decoder activations) and project
        # it onto the vocabulary.
        layers.ThirdBranch(),
        layers.LayerNorm(),
        layers.Dense(vocab_size),
        layers.LogSoftmax(),
    ]
    return layers.Serial(*pipeline)
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Build a single Transformer decoder layer.

    The layer is a causally masked self-attention residual block followed by
    a residual feed-forward block.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Self-attention block: the normalized activations are branched so that
    # one copy feeds (q, k, v) and the other is turned into a causal mask.
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2)),  # attention mask
        layers.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode))

    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    return layers.Serial(self_attention, feed_forward)
def Decoder(memory, target, target_mask, memory_mask):
    """Build the Transformer decoder stack.

    NOTE: besides its parameters, this function reads `multi_attention`,
    `dropout`, `mode`, `feature_depth`, `feedforward_depth`,
    `target_embedding_layer` and `num_layers` from an enclosing scope; they
    are not defined inside this function. The same `multi_attention` object
    is used for both attention sub-blocks.

    Args:
      memory: layer variable: encoded source sequences
      target: layer variable: raw target sequences
      target_mask: layer variable: self-attention mask
      memory_mask: layer variable: memory attention mask

    Returns:
      Layer variable that outputs the layer-normalized activations of the
      decoder stack applied to the embedded target.
    """
    # Target attends to itself.
    target_self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(size=4),
        layers.Parallel(
            layers.Identity(),  # query
            layers.Identity(),  # key
            layers.Identity(),  # value
            target_mask),  # attention mask
        multi_attention,
        layers.Dropout(dropout, mode=mode))

    # Target attends to the encoded source: queries come from the target,
    # keys and values from `memory`.
    memory_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(size=4),
        layers.Parallel(
            layers.Identity(),  # query
            memory,  # key
            memory,  # value
            memory_mask),  # attention mask
        multi_attention,
        layers.Dropout(dropout, mode=mode))

    # Position-wise feed-forward sub-block.
    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    decoder_layer = layers.Serial(
        target_self_attention, memory_attention, feed_forward)

    # Embed the target, apply the layer `num_layers` times, then normalize.
    return layers.Serial(
        target,
        target_embedding_layer,
        layers.repeat(decoder_layer, num_layers),
        layers.LayerNorm(),
    )
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Build a single Transformer encoder-decoder layer.

    The input is a triple (encoder, mask, decoder_input) where the mask is
    created from the original source to prevent attending to the padding part
    of the encoder.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a triple (encoder, mask, decoder_activations).
    """
    # Decoder self-attending to decoder (causally masked).
    self_attention = layers.Residual(
        layers.LayerNorm(),
        layers.Branch(),
        layers.Parallel(
            layers.Identity(),  # activation for (q, k, v)
            layers.CausalMask(axis=-2)),  # attention mask
        layers.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode))

    # Decoder attending to encoder.
    encoder_decoder_attention = layers.Serial(
        layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
        layers.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new v
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode),
    )

    # Self-attention touches only the third component (decoder input).
    apply_self_attention = layers.Parallel(
        layers.Identity(), layers.Identity(), self_attention)

    # Duplicate the triple: one copy is kept as-is, the other is consumed by
    # encoder-decoder attention.
    duplicate = layers.Branch()
    attend_to_encoder = layers.Parallel(
        layers.Identity(), encoder_decoder_attention)
    flatten = layers.UnnestBranches()  # (encoder, mask, old_act, new_act)
    group_activations = layers.Reorder(output=(0, 1, (2, 3)))

    # Residual after encoder-decoder attention.
    add_residual = layers.Parallel(
        layers.Identity(), layers.Identity(), layers.SumBranches())

    # Feed-forward on the third component (decoder).
    apply_feed_forward = layers.Parallel(
        layers.Identity(),
        layers.Identity(),
        ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode))

    return layers.Serial(
        apply_self_attention,
        duplicate,
        attend_to_encoder,
        flatten,
        group_activations,
        add_residual,
        apply_feed_forward)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """WideResnet convolutional block.

    Args:
      channels: number of output channels passed to each convolution.
      strides: strides for the first main-path convolution and, when
        `channel_mismatch` is set, for the shortcut convolution.
      channel_mismatch: if true, the shortcut is a 3x3 convolution instead of
        an identity.

    Returns:
      A layer that sums the main path and the shortcut path.
    """
    # Main path: two (BatchNorm -> Relu -> 3x3 Conv) stages; only the first
    # convolution receives `strides`.
    main = layers.Serial(
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(channels, (3, 3), strides, padding='SAME'),
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(channels, (3, 3), padding='SAME'))

    if channel_mismatch:
        shortcut = layers.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        shortcut = layers.Identity()

    # Branch the input into both paths and add the results.
    return layers.Serial(
        layers.Branch(),
        layers.Parallel(main, shortcut),
        layers.SumBranches())
def IdentityBlock(kernel_size, filters):
    """ResNet identical size block.

    Args:
      kernel_size: side length of the square kernel used by the middle
        convolution.
      filters: sequence of exactly three filter counts, one per convolution
        of the main path.

    Returns:
      A layer that adds the main path to an identity shortcut and applies a
      Relu to the sum.
    """
    middle_kernel = (kernel_size, kernel_size)
    first_filters, middle_filters, last_filters = filters

    # Main path: 1x1 conv -> kernel_size conv -> 1x1 conv, each followed by
    # BatchNorm; the final Relu is applied only after the residual sum.
    main = layers.Serial(
        layers.Conv(first_filters, (1, 1)),
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(middle_filters, middle_kernel, padding='SAME'),
        layers.BatchNorm(),
        layers.Relu(),
        layers.Conv(last_filters, (1, 1)),
        layers.BatchNorm())

    return layers.Serial(
        layers.Branch(),
        layers.Parallel(main, layers.Identity()),
        layers.SumBranches(),
        layers.Relu())
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Build a single Transformer encoder layer.

    The input to the encoder is a pair (embedded source, mask) where the mask
    is created from the original source to prevent attending to the padding
    part of the input.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a pair (activations, mask).
    """
    # Attention block: takes (activation, mask); the residual shortcut uses
    # only the first element of that pair.
    attention = layers.Residual(
        layers.Parallel(layers.LayerNorm(), layers.Identity()),
        layers.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        layers.Dropout(rate=dropout, mode=mode),
        shortcut=layers.FirstBranch())

    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    # The encoder block expects (activation, mask) as input and returns the
    # new activations only.
    encoder_block = layers.Serial(attention, feed_forward)

    # Duplicate the mask up front so it can be re-attached to the output.
    keep_mask = layers.Reorder(output=((0, 1), 1))  # (x, mask) --> ((x, mask), mask)
    return layers.Serial(
        keep_mask,
        layers.Parallel(encoder_block, layers.Identity()))