def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Builds the layer list for one Transformer decoder block.

  The block maps a single activation tensor to a single activation tensor.
  It applies causally masked self-attention, then a feed-forward sublayer.
  Each of the two is wrapped in its own `tl.Residual`.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A list of two residual layers that maps an activation tensor to an
    activation tensor.
  """
  causal_attention = [                          # x
      tl.LayerNorm(),                           # x
      tl.Dup(),                                 # x x
      tl.Parallel([], tl.CausalMask(axis=-2)),  # x mask
      tl.MultiHeadedAttention(
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),  # y mask
      tl.Parallel([], tl.Drop()),               # y   (mask discarded)
      tl.Dropout(rate=dropout, mode=mode),      # y
  ]
  ffn = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]

  # Self-attention first, then feed-forward; each gets its own Residual.
  return [tl.Residual(branch) for branch in (causal_attention, ffn)]
def MultiHeadedAttentionPosition(positions, d_feature, n_heads=8,
                                 dropout=0.0, mode='train'):
  """Transformer-style multi-headed attention with position-aware heads.

  Args:
    positions: passed to each per-head `SumLearnedPick` used on the query
      side (its exact type is defined by `SumLearnedPick`; not shown here).
    d_feature: int: depth of the dense projections and of the attention.
    n_heads: int: number of attention heads.
    dropout: float: attention dropout rate.
    mode: str: 'train' or 'eval'.

  Returns:
    A `tl.Serial` layer. Its attention step consumes a mask alongside the
    activations; the mask is dropped before the output projection.
  """
  # Three copies of the activations: one each for query, key and value.
  fan_out = [tl.Dup(), tl.Dup()]

  # Query projection additionally takes one SumLearnedPick(positions) per head;
  # key and value use plain dense projections wrapped in PreservePosition.
  query_projection = ApplyAndQueryPositions(
      tl.Dense(d_feature),
      pos=[SumLearnedPick(positions) for _ in range(n_heads)])
  key_projection = PreservePosition(tl.Dense(d_feature))
  value_projection = PreservePosition(tl.Dense(d_feature))
  project_qkv = tl.Parallel(query_projection, key_projection, value_projection)

  # Arrange q / k / v into heads (queries are copied, keys/values are mixed).
  split_heads = tl.Parallel(
      CopyHeadsPos(h=n_heads),
      MixHeadsPos(h=n_heads),
      MixHeadsPos(h=n_heads),
  )

  attend = tl.PureMultiHeadedAttention(
      d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
  discard_mask = tl.Parallel([], tl.Drop())
  merge_heads = CombineHeadsPos(h=n_heads)
  output_projection = PreservePosition(tl.Dense(d_feature))

  return tl.Serial(
      *fan_out,
      project_qkv,
      split_heads,
      attend,
      discard_mask,
      merge_heads,
      output_projection,
  )
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder). The mask is created
  from the original source so that the decoder does not attend to the
  padding part of the encoder output.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A list of three residual layers; the result is a triple
    (decoder_activations, mask, encoder).
  """
  # TODO(jonni): Work on combinators so that this flow is cleaner/clearer.
  self_attn = [
      tl.LayerNorm(),
      tl.Dup(),
      tl.CausalMask(axis=-2),   # Build the self-attention mask.
      tl.Swap(),                # Move the mask underneath the activations.
      tl.MultiHeadedAttention(
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Swap(),                # Bring the self-attention mask back on top...
      tl.Drop(),                # ...and throw it away.
      tl.Dropout(rate=dropout, mode=mode),
  ]

  cross_attn = [
      # (dec, mask, enc) -> (dec, enc, enc, mask, enc-copy)
      tl.Select((0, 2, 2, 1, 2)),
      # (q, k, v, mask, ...) --> (new, mask, ...)
      tl.MultiHeadedAttentionQKV(
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]

  ffn = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]

  return [tl.Residual(sublayer) for sublayer in (self_attn, cross_attn, ffn)]
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder). The mask is created
  from the original source to keep the decoder from attending to the
  padding part of the encoder output.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A list of three residual layers; the result is a triple
    (decoder_activations, mask, encoder).
  """
  causal_self_attention = [                     # vecs_d pmask vecs_e
      tl.LayerNorm(),                           # vecs_d ..... ......
      tl.Dup(),                                 # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.CausalMask(axis=-2)),  # ______ masks ..... ......
      tl.MultiHeadedAttention(
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # ______ 0 ..... ......
      tl.Dropout(rate=dropout, mode=mode),      # vecs_d ..... ......
  ]

  # NOTE(review): unlike the self-attention branch, this branch applies no
  # LayerNorm to vecs_d before attending; confirm this is intended.
  attend_to_encoder = [                         # vecs_d masks vecs_e
      tl.Parallel([], [], tl.Dup()),            # vecs_d masks vecs_e vecs_e
      tl.Parallel([], tl.Swap()),               # vecs_d vecs_e masks vecs_e
      tl.Parallel([], tl.Dup()),                # vecs_d vecs_e vecs_e masks vecs_e
      tl.MultiHeadedAttentionQKV(               # (q k v masks ... --> vecs_d masks ...)
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),      # vecs_d masks vecs_e
  ]

  position_ffn = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]

  # Every stage maps (vecs_d, masks, vecs_e) -> (vecs_d, masks, vecs_e).
  stages = (causal_self_attention, attend_to_encoder, position_ffn)
  return [tl.Residual(stage) for stage in stages]
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder). The mask is created
  from the original source to prevent attending to the padding part of the
  encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A `tl.Serial` layer returning a triple
    (decoder_activations, mask, encoder).
  """
  # Decoder self-attending to decoder.
  decoder_self_layers = [
      tl.LayerNorm(),
      tl.Dup(),
      tl.CausalMask(axis=-2),  # Create the self-attention mask.
      tl.Swap(),               # Put mask behind the activations.
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Swap(),               # Put self-attention mask on top.
      tl.Drop(),               # Drop self-attention mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  self_attention = tl.Residual(*decoder_self_layers)

  # Decoder attending to encoder.
  cross_layers = [
      tl.Select((0, 2, 2, 1, 2)),  # (dec, enc, enc, mask, enc-copy)
      tl.MultiHeadedAttentionQKV(  # (q, k, v, mask, ...) --> (new, mask, ...)
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  encoder_decoder_attention = tl.Serial(*cross_layers)

  pipeline = [
      self_attention,
      tl.Residual(encoder_decoder_attention),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode),
  ]
  return tl.Serial(*pipeline)
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Returns a Transformer encoder model that classifies token sequences.

  The input to the model is a tensor of tokens.

  Args:
    vocab_size: int: vocab size
    n_classes: int: how many classes on output
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder blocks
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a tensor of tokens to
    log-probabilities (LogSoftmax) over `n_classes` output classes.
  """
  token_embedding = [
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  # tokens -> toks toks -> vecs mask
  embed_and_mask = [
      tl.Dup(),
      tl.Parallel(token_embedding, tl.PaddingMask()),
  ]

  # vecs mask -> vecs mask
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode)
      for layer_idx in range(n_layers)
  ]

  # vecs mask -> class log-probabilities
  classifier = [
      tl.Parallel([], tl.Drop()),  # Discard the mask.
      tl.LayerNorm(),
      tl.Mean(axis=1),             # Average on length.
      tl.Dense(n_classes),
      tl.LogSoftmax(),
  ]

  return tl.Model(embed_and_mask + [encoder_blocks] + classifier)
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for memory-efficient attention
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A one-element list holding a `ReversibleResidual` whose two branches are
    the chunked self-attention and the feed-forward sublayer.
  """
  masked_attention = [
      tl.LayerNorm(),
      tl.Dup(),
      tl.Parallel([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]

  # TODO(kitaev): Memory-efficient attention. This chunking is temporary.
  # NOTE(review): the input is split into n_attention_chunks pieces on
  # axis -2 and `masked_attention` is mapped over them, so presumably no
  # position can attend across a chunk boundary; confirm this is acceptable.
  chunked_attention = [
      Split(sections=n_attention_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map(masked_attention),
      tl.Concatenate(axis=-2),
  ]

  ffn = [FeedForward(d_feature, d_feedforward, dropout, mode=mode)]

  return [ReversibleResidual([chunked_attention], [ffn])]
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Returns a Transformer encoder-decoder model.

  This model expects an input pair (source, target). Encoder-side tokens
  come first and decoder-side tokens second.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the decoder
      reuses the encoder's embedding layer objects.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a (source, target) pair to
    log-probabilities (LogSoftmax) over the target vocab set.
  """
  in_embed = [                                   # tokens
      tl.Embedding(d_model, input_vocab_size),   # vecs
      tl.Dropout(rate=dropout, mode=mode),       # vecs
      tl.PositionalEncoding(max_len=max_len),    # vecs
  ]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
    # The same layer objects are reused on the decoder side.
    # NOTE(review): this relies on tl sharing weights between reused layer
    # instances; confirm.
    out_embed = in_embed
  else:
    out_embed = [                                  # tokens
        tl.Embedding(d_model, output_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),       # vecs
        tl.PositionalEncoding(max_len=max_len),    # vecs
    ]

  encoder_stack = [  # vecs_e masks --> vecs_e masks
      EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
      for i in range(n_layers)
  ]

  # EncoderDecoder, as defined in this module, takes
  # (d_feature, d_feedforward, n_heads, dropout, mode) and has no layer-index
  # parameter. Passing the loop index as an extra positional argument raised a
  # TypeError.
  encoder_decoder_stack = [  # vecs_d masks vecs_e --> vecs_d masks vecs_e
      EncoderDecoder(d_model, d_ff, n_heads, dropout, mode)
      for _ in range(n_layers)
  ]

  # Input: encoder_side_tokens, decoder_side_tokens
  return tl.Model(                                 # tokens_e tokens_d
      tl.Swap(),                                   # toks_d toks_e

      # Encode.
      tl.Parallel(                                 # toks_d toks_e
          [],
          [tl.Dup(),                               # ______ toks_e toks_e
           tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
           encoder_stack,                          # ______ vecs_e masks
           tl.LayerNorm(),                         # ______ vecs_e .....
           tl.Swap()]),                            # ______ masks vecs_e

      # Decode.                                    # toks_d masks vecs_e
      tl.ShiftRight(),                             # toks_d ..... ......
      out_embed,                                   # vecs_d ..... ......
      tl.Dup(),                                    # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.EncoderDecoderMask()),    # ______ masks ......
      encoder_decoder_stack,                       # vecs_d masks vecs_e
      tl.Parallel([], tl.Drop(), tl.Drop()),       # vecs_d
      tl.LayerNorm(),                              # vecs_d
      tl.Dense(output_vocab_size),                 # vecs_d
      tl.LogSoftmax(),                             # vecs_d
  )