def Encoder(source, source_mask):
  """Transformer encoder stack.

  Args:
    source: layer variable: raw source sequences
    source_mask: layer variable: self-attention mask

  Returns:
    Layer variable that outputs encoded source.
  """
  encoder_layer = layers.Serial(
      # input attends to self
      layers.Residual(
          layers.LayerNorm(),
          layers.Branch(size=4),
          layers.Parallel(layers.Identity(),  # query
                          layers.Identity(),  # key
                          layers.Identity(),  # value
                          source_mask),       # attention mask
          multi_attention,
          layers.Dropout(dropout, mode=mode)),
      # feed-forward
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode),
  )
  return layers.Serial(
      source,
      source_embedding_layer,
      layers.repeat(encoder_layer, num_layers),
      layers.LayerNorm(),
  )
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_loop_stride=0,
                        mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of *each half* of the two-part features
    d_feedforward: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_loop_stride: int: number of query elements to compute attention
      for in parallel. Set to 0 to disable memory-efficient attention.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positional_embedder = [
      tl.Embedding(d_feature, vocab_size),
      # TODO(kitaev): add dropout
      tl.PositionalEncoding(max_len=max_len),
  ]
  return tl.Model(
      tl.Concatenate(n_items=n_chunks),
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      ReversibleSerial([  # pylint: disable=g-complex-comprehension
          DecoderBlock(d_feature, d_feedforward,
                       d_attention_key, d_attention_value, n_heads,
                       n_attention_chunks, attention_loop_stride,
                       dropout, mode)
          for _ in range(n_layers)
      ]),
      tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
      tl.Concatenate(),
      Split(sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map([
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], sections=n_chunks),
  )
def TransformerRevnetLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_type=DotProductAttention,
                        mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.PositionalEncoding(max_len=max_len),
  ]
  return tl.Model(
      tl.Concatenate(n_items=n_chunks),
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial([  # pylint: disable=g-complex-comprehension
          DecoderBlock(d_model, d_ff,
                       d_attention_key, d_attention_value, n_heads,
                       n_attention_chunks, attention_type,
                       dropout, mode)
          for _ in range(n_layers)
      ]),
      tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
      tl.Concatenate(),
      Split(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map([
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )
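# A minimal numpy sketch (an illustration, not trax code) of the chunk
# bookkeeping above: the input arrives pre-split into `n_chunks` pieces along
# the time axis, `Concatenate(n_items=n_chunks)` joins them for processing,
# and `Split(n_sections=n_chunks, axis=-2)` restores the chunked layout at the
# end. The shapes below are arbitrary example values.
import numpy as np

n_chunks, batch, chunk_len, depth = 4, 2, 8, 16
chunks = [np.random.randn(batch, chunk_len, depth) for _ in range(n_chunks)]
joined = np.concatenate(chunks, axis=-2)        # (batch, n_chunks*chunk_len, depth)
restored = np.split(joined, n_chunks, axis=-2)  # back to n_chunks pieces
assert all(np.allclose(a, b) for a, b in zip(chunks, restored))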
def Transformer(vocab_size,
                d_feature=512,
                d_feedforward=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects as input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  positional_embedder = [
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]
  encoder = [
      tl.Branch(positional_embedder, tl.PaddingMask()),
      [EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.LayerNorm(),
  ]
  return tl.Model(
      tl.Parallel([], tl.ShiftRight()),
      tl.Parallel(encoder, positional_embedder),
      tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                output=('decoder', ('mask', 'decoder'), 'encoder')),
      # (encoder_mask, decoder_input) -> encoder-decoder mask
      tl.Parallel([], tl.EncoderDecoderMask(), []),
      [EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.Select(0),  # Drop mask and encoder.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects as input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  embedding = layers.Serial(
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len))
  encoder = layers.Serial(
      layers.Branch(),  # Branch input to create embedding and mask.
      layers.Parallel(embedding, layers.PaddingMask()),
      layers.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                                   dropout, mode)
                      for _ in range(num_layers)]),
      layers.Parallel(layers.LayerNorm(), layers.Identity()))
  stack = [EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
           for _ in range(num_layers)]
  return layers.Serial(
      layers.Parallel(layers.Identity(), layers.ShiftRight()),
      layers.Parallel(encoder, embedding),
      layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
      layers.Reorder(output=(0, (1, 2), 2)),
      layers.Parallel(  # (encoder_mask, decoder_input) -> encoder-decoder mask
          layers.Identity(), layers.EncoderDecoderMask(), layers.Identity()),
      layers.Serial(*stack),
      layers.ThirdBranch(),
      layers.LayerNorm(),
      layers.Dense(vocab_size),
      layers.LogSoftmax())
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of *each half* of the two-part features
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for memory-efficient attention
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positional_embedder = [
      tl.Embedding(d_feature, vocab_size),
      # TODO(kitaev): dropout is disabled to save memory
      # tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]
  return tl.Model(
      tl.Concatenate(),
      tl.ShiftRight(),
      positional_embedder,
      Duplicate(),  # pylint: disable=no-value-for-parameter
      ReversibleSerial([
          DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                       dropout, mode)
          for _ in range(n_layers)
      ]),
      tl.Parallel(tl.LayerNorm(), tl.LayerNorm()),
      tl.Concatenate(),
      Split(sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map([
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ]),
  )
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects as input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  embedding = tl.Serial(
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len))
  encoder = tl.Serial(
      tl.Branch(embedding, tl.PaddingMask()),
      tl.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.Parallel(tl.LayerNorm(), tl.NoOp()))
  stack = [EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
           for _ in range(num_layers)]
  return tl.Serial(
      tl.Parallel(tl.NoOp(), tl.ShiftRight()),
      tl.Parallel(encoder, embedding),
      tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                output=('encoder', ('mask', 'decoder'), 'decoder')),
      tl.Parallel(  # (encoder_mask, decoder_input) -> encoder-decoder mask
          tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp()),
      tl.Serial(*stack),
      tl.Select(2),  # Drop encoder and mask.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax())
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.ShiftRight(),
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
      tl.Serial(*[DecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax())
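# A minimal numpy sketch (an assumption about the mechanics, not the library
# implementation) of the two input transforms used above: ShiftRight pads the
# target ids with a zero at the front so position t only conditions on targets
# < t, and PositionalEncoding adds the standard sinusoidal table (an even
# feature_depth is assumed here).
import numpy as np

def shift_right(token_ids):
  # token_ids: (batch, length) int array of target ids.
  return np.pad(token_ids, ((0, 0), (1, 0)), mode='constant')[:, :-1]

def sinusoidal_encoding(length, feature_depth):
  positions = np.arange(length)[:, None]
  rates = np.exp(np.arange(0, feature_depth, 2) *
                 (-np.log(10000.0) / feature_depth))
  table = np.zeros((length, feature_depth))
  table[:, 0::2] = np.sin(positions * rates)
  table[:, 1::2] = np.cos(positions * rates)
  return table  # added to the (batch, length, feature_depth) embeddings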
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Select(0),  # Drop the mask.
          tl.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode))
def EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder block.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  attention = [
      tl.LayerNorm(),
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(attention),
      tl.Residual(feed_forward),
  ]
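# A minimal numpy sketch (an assumption, not the library code) of the mask
# that travels alongside the activations: it marks non-padding source
# positions, and masked-out attention logits are pushed to a large negative
# value before the softmax.
import numpy as np

def padding_mask(token_ids, pad_id=0):
  # token_ids: (batch, length) -> (batch, 1, 1, length), broadcastable against
  # attention logits of shape (batch, n_heads, length, length).
  return (token_ids != pad_id)[:, None, None, :]

# Typical use inside attention: logits = np.where(mask, logits, -1e9)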
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  self_attention = [
      tl.LayerNorm(),
      tl.Branch([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
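# A minimal numpy sketch (an illustration, not the library's attention) of the
# causal self-attention the block above performs, for a single example and a
# single head: position i may only attend to positions <= i.
import numpy as np

def causal_attention(q, k, v):
  # q, k, v: (length, d_head).
  length, d_head = q.shape
  logits = q @ k.T / np.sqrt(d_head)                       # (length, length)
  causal = np.tril(np.ones((length, length), dtype=bool))  # lower-triangular
  logits = np.where(causal, logits, -1e9)
  logits -= logits.max(axis=-1, keepdims=True)
  weights = np.exp(logits)
  weights /= weights.sum(axis=-1, keepdims=True)
  return weights @ v                                       # (length, d_head)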
def ChunkedDecoderLayer(d_feature,
                        d_feedforward,
                        n_heads,
                        dropout,
                        chunk_selector,
                        mode):
  """Transformer decoder layer operating on chunks.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    The layers comprising a chunked decoder.
  """
  return [
      Residual(  # Self-attention block.
          tl.Map(tl.LayerNorm()),
          ChunkedCausalMultiHeadedAttention(
              d_feature, n_heads=n_heads, dropout=dropout,
              chunk_selector=chunk_selector, mode=mode),
          tl.Map(tl.Dropout(rate=dropout, mode=mode)),
      ),
      tl.Map(ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode))
  ]
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  return tl.Serial(
      tl.Residual(  # Attention block here.
          tl.Parallel(tl.LayerNorm(), tl.Copy()),
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())),
      tl.Parallel(
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode),
          tl.Div(divisor=2.0)  # Mask added to itself in the residual, divide.
      ))
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  pre_attention = [
      Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      tl.LayerNorm(),
      tl.Dup(), tl.Dup(),
      tl.Parallel(
          [tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key)],
          [tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key)],
          [tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value)],
      ),
  ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_model, d_ff, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
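# A pure-Python sketch (an illustration of the idea, not the trax
# ReversibleSerial machinery) of the reversible-residual coupling these blocks
# rely on: with two activation streams, the inputs can be recomputed from the
# outputs, so activations need not be stored for the backward pass.
def reversible_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def reversible_inverse(y1, y2, f, g):
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2

# Round trip with toy functions:
y1, y2 = reversible_forward(1.0, 2.0, lambda t: 3 * t, lambda t: t + 1)
assert reversible_inverse(y1, y2, lambda t: 3 * t, lambda t: t + 1) == (1.0, 2.0)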
def EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode):
  """Returns a layer sequence that implements a Transformer encoder block.

  The input to the layer sequence is a pair, (activations, mask), where the
  mask was created from the original source tokens to prevent attending to
  the padding part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an (activations, mask) pair to an
    (activations, mask) pair.
  """
  attention = [
      tl.LayerNorm(),
      tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, name='enc_attn_dropout', mode=mode),
  ]
  feed_forward = [
      FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
  ]
  return [
      tl.Residual(attention),
      tl.Residual(feed_forward),
  ]
def EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Returns a layer sequence that implements a Transformer encoder block.

  The input to the layer sequence is a pair, (activations, mask), where the
  mask was created from the original source tokens to prevent attending to
  the padding part of the input.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an (activations, mask) pair to an
    (activations, mask) pair.
  """
  attention = [
      tl.LayerNorm(),
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(attention),
      tl.Residual(feed_forward),
  ]
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Residual feed-forward layer with normalization at start."""
  return layers.Residual(
      layers.LayerNorm(),
      layers.Dense(feedforward_depth),
      layers.Relu(),
      layers.Dropout(rate=dropout, mode=mode),
      layers.Dense(feature_depth),
      layers.Dropout(rate=dropout, mode=mode))
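# A minimal numpy sketch (an assumption, not the library code) of what the
# residual feed-forward block computes at eval time, with dropout and the
# LayerNorm scale/bias omitted: y = x + relu(layer_norm(x) @ w1 + b1) @ w2 + b2.
import numpy as np

def residual_feed_forward(x, w1, b1, w2, b2, eps=1e-6):
  # x: (length, feature_depth); w1: (feature_depth, feedforward_depth);
  # w2: (feedforward_depth, feature_depth).
  normed = (x - x.mean(axis=-1, keepdims=True)) / np.sqrt(
      x.var(axis=-1, keepdims=True) + eps)
  hidden = np.maximum(normed @ w1 + b1, 0.0)  # Dense + Relu
  return x + (hidden @ w2 + b2)               # Dense + residual connection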
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Returns a layer sequence that implements a Transformer decoder block.

  The input to the layer sequence is an activation tensor.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an activation tensor to an activation
    tensor.
  """
  self_attention = [
      tl.LayerNorm(),                           # vec
      tl.Dup(),                                 # vec vec
      tl.Parallel([], tl.CausalMask(axis=-2)),  # vec mask
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # vec
      tl.Dropout(rate=dropout, mode=mode),      # vec
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
def ChunkedDecoderLayer(feature_depth,
                        feedforward_depth,
                        num_heads,
                        dropout,
                        chunk_selector,
                        mode):
  """Transformer decoder layer operating on chunks.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.Map(tl.LayerNorm()),
          ChunkedCausalMultiHeadedAttention(
              feature_depth, num_heads=num_heads, dropout=dropout,
              chunk_selector=chunk_selector, mode=mode),
          tl.Map(tl.Dropout(rate=dropout, mode=mode)),
      ),
      tl.Map(ResidualFeedForward(
          feature_depth, feedforward_depth, dropout, mode=mode))
  )
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial([
      tl.ShiftRight(),
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      NewPositionalEncoding(positions=positions),
      [DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout,
                    mode)
       for _ in range(n_layers)],
      PreservePosition(tl.LayerNorm()),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  ])
def DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    positions: random vectors for positions
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return [
      tl.Residual(  # Self-attention block.
          PreservePosition(tl.LayerNorm()),
          tl.Dup(),
          tl.Parallel([],                       # activation for (q, k, v)
                      tl.CausalMask(axis=-2)),  # attention mask
          MultiHeadedAttentionPosition(positions, d_feature, n_heads=n_heads,
                                       dropout=dropout, mode=mode),
          PreservePosition(tl.Dropout(rate=dropout, mode=mode))),
      ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode)
  ]
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return layers.Serial(
      layers.Residual(  # Self-attention block.
          layers.LayerNorm(),
          layers.Branch(),
          layers.Parallel(layers.Identity(),  # activation for (q, k, v)
                          layers.CausalMask(axis=-2)),  # attention mask
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode))
def TransformerDecoder(d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Returns a Transformer decoder model.

  The input to the model is a continuous tensor. Does not shift the input to
  the right, i.e. the output for timestep t is based on inputs up to timestep
  t inclusively.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer decoder as a layer that maps from a continuous tensor to
    a continuous tensor.
  """
  return tl.Model(                 # vecs
      tl.PositionalEncoding(max_len=max_len),
      tl.Dense(d_model),           # vecs
      [DecoderBlock(d_model, d_ff, n_heads, dropout, mode)
       for _ in range(n_layers)],  # vecs
      tl.LayerNorm(),              # vecs
  )
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode):
  """Returns a layer sequence that implements a Transformer decoder block.

  The input to the layer sequence is an activation tensor.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an activation tensor to an activation
    tensor.
  """
  self_attention = [
      tl.LayerNorm(),  # vec
      tl.BasicCausalAttention(d_model, n_heads=n_heads, dropout=dropout,
                              mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  # vec
  ]
  feed_forward = [
      FeedForward(d_model, d_ff, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 attention_loop_stride, dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_loop_stride: int: number of query elements to compute attention
      for in parallel. Set to 0 to disable memory-efficient attention.
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  pre_attention = [
      Chunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      tl.LayerNorm(),
      tl.Dup(), tl.Dup(),
      tl.Parallel(
          [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
          [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
          [tl.Dense(d_feature), SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
      ),
  ]

  # TODO(kitaev): add dropout
  if attention_loop_stride < 1:
    # Use the standard implementation if no loop_stride is provided.
    attention = DotProductAttention(dropout=None, mode=mode)
  else:
    attention = MemoryEfficientDotProductAttention(
        loop_stride=attention_loop_stride, dropout=None, mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      JoinHeads(),  # pylint: disable=no-value-for-parameter
      tl.Dense(d_feature),
      Unchunk(sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      ReversibleSwap(),
  ]
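# A minimal numpy sketch (an illustration, not MemoryEfficientDotProductAttention
# itself) of the loop-stride idea: processing `loop_stride` queries at a time
# bounds the size of the logits matrix held in memory, trading speed for
# memory. Single example, single head.
import numpy as np

def loop_stride_causal_attention(q, k, v, loop_stride):
  # q, k, v: (length, d_head).
  length, d_head = q.shape
  out = np.zeros_like(v)
  for start in range(0, length, loop_stride):
    stop = min(start + loop_stride, length)
    logits = q[start:stop] @ k.T / np.sqrt(d_head)  # (stride, length)
    causal = np.arange(length)[None, :] <= np.arange(start, stop)[:, None]
    logits = np.where(causal, logits, -1e9)
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    out[start:stop] = (weights / weights.sum(axis=-1, keepdims=True)) @ v
  return out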
def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  decoder_self_attention = [                #  vecs_d   pmask vecs_e
      tl.LayerNorm(),                       #  vecs_d   ..... ......
      tl.BasicCausalAttention(d_model, n_heads=n_heads, dropout=dropout,
                              mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  #  vecs_d   ..... ......
  ]
  decoder_to_encoder_attention = [     # vecs_d        masks         vecs_e
      tl.LayerNorm(),                  # vecs_d        masks         vecs_e
      tl.Parallel([], [], tl.Dup()),   # ______        _____  vecs_e vecs_e
      tl.Parallel([], tl.Swap()),      # ______        vecs_e masks  ......
      tl.Parallel([], tl.Dup()),       # ______ vecs_e vecs_e .....  ......
      tl.AttentionQKV(  # (q k v masks ... --> vecs_d masks ...)
          d_model, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  # vecs_d mask vecs_e
  ]
  feed_forward = [
      FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
  ]
  return [                                        # vecs_d masks vecs_e
      tl.Residual(decoder_self_attention),        # vecs_d masks vecs_e
      tl.Residual(decoder_to_encoder_attention),  # vecs_d masks vecs_e
      tl.Residual(feed_forward),                  # vecs_d masks vecs_e
  ]
def Decoder(memory, target, target_mask, memory_mask):
  """Transformer decoder stack.

  Args:
    memory: layer variable: encoded source sequences
    target: layer variable: raw target sequences
    target_mask: layer variable: self-attention mask
    memory_mask: layer variable: memory attention mask

  Returns:
    Layer variable that outputs decoded target.
  """
  decoder_layer = layers.Serial(
      # target attends to self
      layers.Residual(
          layers.LayerNorm(),
          layers.Branch(size=4),
          layers.Parallel(layers.Identity(),  # query
                          layers.Identity(),  # key
                          layers.Identity(),  # value
                          target_mask),       # attention mask
          multi_attention,
          layers.Dropout(dropout, mode=mode)),
      # target attends to encoded source
      layers.Residual(
          layers.LayerNorm(),
          layers.Branch(size=4),
          layers.Parallel(layers.Identity(),  # query
                          memory,             # key
                          memory,             # value
                          memory_mask),       # attention mask
          multi_attention,
          layers.Dropout(dropout, mode=mode)),
      # feed-forward
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode))
  return layers.Serial(
      target,
      target_embedding_layer,
      layers.repeat(decoder_layer, num_layers),
      layers.LayerNorm(),
  )
def FeedForward(d_model, d_ff, dropout, layer_idx, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      tl.Relu(),
      tl.Dropout(rate=dropout, name='ff_middle_%d' % layer_idx, mode=mode),
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, name='ff_final_%d' % layer_idx, mode=mode),
  ]
def FeedForward(d_feature, d_feedforward, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_feedforward),
      tl.Relu(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_feature),
      tl.Dropout(rate=dropout, mode=mode),
  ]
def FeedForward(d_model, d_ff, dropout, mode):
  """Feed-forward block with layer normalization at start."""
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Relu(),
      tl.Dense(d_model),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]
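# A minimal numpy sketch of broadcasted dropout, assuming (this reading is an
# assumption) that BroadcastedDropout samples one mask per feature and
# broadcasts it along the length axis, so the stored mask is much smaller than
# the activations.
import numpy as np

def broadcasted_dropout(x, rate, rng=None):
  # x: (batch, length, depth); a mask of shape (batch, 1, depth) broadcasts
  # over the length axis. Inverted dropout: kept units are scaled by 1/(1-rate).
  rng = rng or np.random.RandomState(0)
  keep = rng.uniform(size=(x.shape[0], 1, x.shape[-1])) > rate
  return np.where(keep, x / (1.0 - rate), 0.0)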