def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.9,
                  max_len=256):
  """Build a decoder-only Transformer language model.

  Only the decoder half of the Transformer is used: inputs are shifted right,
  embedded, position-encoded, run through `num_layers` identical decoder
  layers (self-attention under a CausalMask, then feed-forward), normalized,
  and projected to log-probabilities over the vocabulary.

  Args:
    vocab_size: int: vocab size.
    mode: str: 'train' or 'eval'; forwarded to every dropout-bearing layer.
    num_layers: int: number of stacked decoder layers.
    feature_depth: int: depth of embedding.
    feedforward_depth: int: depth of the hidden feed-forward layer.
    num_heads: int: number of attention heads.
    dropout: float: dropout parameter handed to the stax dropout layers.
      Stax follows TF's KEEP probability convention, so larger values
      drop less.
    max_len: int: maximum symbol length for positional encoding.

  Returns:
    init and apply.
  """
  # Building blocks reused inside every decoder layer.
  attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
  position_wise_ffn = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
  )

  # Residual self-attention: the normalized input is routed to query, key and
  # value, with a CausalMask over axis -2 supplied as the attention mask.
  causal_self_attention = stax.residual(
      stax.LayerNorm(feature_depth),
      stax.multiplex(stax.Identity,              # query
                     stax.Identity,              # key
                     stax.Identity,              # value
                     stax.CausalMask(axis=-2)),  # attention mask
      attention,
      stax.Dropout(dropout, mode=mode),
  )
  # Residual position-wise feed-forward.
  feed_forward_block = stax.residual(
      stax.LayerNorm(feature_depth),
      position_wise_ffn,
      stax.Dropout(dropout, mode=mode),
  )
  one_decoder_layer = stax.serial(causal_self_attention, feed_forward_block)

  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.Dropout(dropout, mode=mode),
      stax.repeat(one_decoder_layer, num_layers),
      stax.LayerNorm(feature_depth),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax,
  )
def encoder(embedded_source, source_mask):
  """Transformer encoder stack over already-embedded source sequences.

  Each of the `num_layers` layers is a residual self-attention sublayer
  (masked by `source_mask`) followed by a residual feed-forward sublayer; a
  final LayerNorm is applied to the stack output.

  NOTE(review): feature_depth, multi_attention, feed_forward, dropout, mode
  and num_layers are read but not defined here; presumably they come from an
  enclosing model-building scope — confirm. Another `encoder` definition with
  a different signature also appears in this file.

  Args:
    embedded_source: staxlayer variable: embedded source sequences.
    source_mask: staxlayer variable: self-attention mask.

  Returns:
    Staxlayer variable that outputs encoded source.
  """
  self_attention = stax.residual(
      stax.LayerNorm(feature_depth),
      stax.multiplex(stax.Identity,  # query
                     stax.Identity,  # key
                     stax.Identity,  # value
                     source_mask),   # attention mask
      multi_attention,
      stax.Dropout(dropout, mode=mode),
  )
  position_wise = stax.residual(
      stax.LayerNorm(feature_depth),
      feed_forward,
      stax.Dropout(dropout, mode=mode),
  )
  single_layer = stax.serial(self_attention, position_wise)

  return stax.serial(
      embedded_source,
      stax.repeat(single_layer, num_layers),
      stax.LayerNorm(feature_depth),
  )
def encoder(source, source_mask):
  """Transformer encoder stack that embeds and encodes raw source sequences.

  The raw source is passed through `source_embedding_layer`, then through
  `num_layers` copies of (residual masked self-attention, residual
  feed-forward), and finally through a LayerNorm.

  NOTE(review): multi_attention, feed_forward, keep_rate, mode, num_layers
  and source_embedding_layer are read but not defined here; presumably they
  are bound by an enclosing scope — confirm. Another `encoder` definition with
  a different signature also appears in this file.

  Args:
    source: staxlayer variable: raw source sequences.
    source_mask: staxlayer variable: self-attention mask.

  Returns:
    Staxlayer variable that outputs encoded source.
  """
  # FanOut(4) feeds the four parallel branches: query, key, value, mask.
  self_attention = stax.residual(
      stax.LayerNorm(),
      stax.FanOut(4),
      stax.parallel(stax.Identity,  # query
                    stax.Identity,  # key
                    stax.Identity,  # value
                    source_mask),   # attention mask
      multi_attention,
      stax.Dropout(keep_rate, mode=mode),
  )
  position_wise = stax.residual(
      stax.LayerNorm(),
      feed_forward,
      stax.Dropout(keep_rate, mode=mode),
  )
  single_layer = stax.serial(self_attention, position_wise)

  return stax.serial(
      source,
      source_embedding_layer,
      stax.repeat(single_layer, num_layers),
      stax.LayerNorm(),
  )
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Build a single Transformer decoder layer.

  The layer is a residual self-attention sublayer (LayerNorm first, with a
  CausalMask over axis -2 as the attention mask) followed by
  ResidualFeedForward.

  Args:
    feature_depth: int: depth of embedding.
    feedforward_depth: int: depth of feed-forward layer.
    num_heads: int: number of attention heads.
    dropout: float: dropout rate (how much to drop out).
    mode: str: 'train' or 'eval'.

  Returns:
    init and apply.
  """
  # FanOut(4) feeds the four parallel branches: query, key, value, mask.
  causal_self_attention = stax.residual(
      stax.LayerNorm(),
      stax.FanOut(4),
      stax.parallel(stax.Identity,              # query
                    stax.Identity,              # key
                    stax.Identity,              # value
                    stax.CausalMask(axis=-2)),  # attention mask
      stax.MultiHeadedAttention(
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      stax.Dropout(dropout, mode=mode),
  )
  feed_forward_sublayer = ResidualFeedForward(
      feature_depth, feedforward_depth, dropout, mode=mode)
  return stax.serial(causal_self_attention, feed_forward_sublayer)
def decoder(memory, target, target_mask, memory_mask):
  """Transformer decoder stack.

  The raw target is passed through `target_embedding_layer`, then through
  `num_layers` copies of a decoder layer, and finally through a LayerNorm.
  Each decoder layer has three residual sublayers, each starting with a
  LayerNorm:
    1. self-attention over the target, masked by `target_mask`;
    2. attention whose query is the target stream and whose key/value come
       from `memory`, masked by `memory_mask`;
    3. feed-forward.

  NOTE(review): multi_attention, feed_forward, keep_rate, mode, num_layers
  and target_embedding_layer are read but not defined here; presumably they
  are bound by an enclosing scope — confirm.

  Args:
    memory: staxlayer variable: encoded source sequences.
    target: staxlayer variable: raw target sequences.
    target_mask: staxlayer variable: self-attention mask.
    memory_mask: staxlayer variable: memory attention mask.

  Returns:
    Staxlayer variable that outputs the decoded target representation (the
    normalized output of the decoder stack), not the encoded source.
  """
  # FanOut(4) feeds the four parallel branches: query, key, value, mask.
  target_self_attention = stax.residual(
      stax.LayerNorm(),
      stax.FanOut(4),
      stax.parallel(
          stax.Identity,  # query
          stax.Identity,  # key
          stax.Identity,  # value
          target_mask),   # attention mask
      multi_attention,
      stax.Dropout(keep_rate, mode=mode),
  )
  # Query stays on the target stream; key and value are taken from memory.
  memory_attention = stax.residual(
      stax.LayerNorm(),
      stax.FanOut(4),
      stax.parallel(
          stax.Identity,  # query
          memory,         # key
          memory,         # value
          memory_mask),   # attention mask
      multi_attention,
      stax.Dropout(keep_rate, mode=mode),
  )
  position_wise = stax.residual(
      stax.LayerNorm(),
      feed_forward,
      stax.Dropout(keep_rate, mode=mode),
  )
  decoder_layer = stax.serial(
      target_self_attention, memory_attention, position_wise)

  return stax.serial(
      target,
      target_embedding_layer,
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(),
  )
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Residual position-wise feed-forward sublayer, normalized at the start.

  Inside the residual branch the input goes through: LayerNorm, Dense up to
  `feedforward_depth`, Relu, Dropout, Dense back to `feature_depth`, Dropout.

  Args:
    feature_depth: int: output depth of the branch (matches the residual input
      depth — assumed, confirm against callers).
    feedforward_depth: int: depth of the hidden Dense layer.
    dropout: float: value forwarded to both stax.Dropout layers.
    mode: str: forwarded to both stax.Dropout layers ('train' or 'eval').

  Returns:
    The stax.residual layer wrapping the branch.
  """
  normalize = stax.LayerNorm()
  expand = stax.Dense(feedforward_depth, W_init=stax.xavier_uniform())
  hidden_dropout = stax.Dropout(dropout, mode=mode)
  project = stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  output_dropout = stax.Dropout(dropout, mode=mode)
  return stax.residual(normalize, expand, stax.Relu, hidden_dropout,
                       project, output_dropout)