def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  return tl.Serial(
      tl.Residual(  # Attention block here.
          tl.Parallel(tl.LayerNorm(), tl.Copy()),
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.Copy())),
      tl.Parallel(
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode),
          tl.Div(divisor=2.0)  # Mask added to itself in the residual; divide.
      ))
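
# Why the tl.Div(divisor=2.0) above: tl.Residual adds the block's output to
# its input elementwise, and since the mask passes through the attention
# block unchanged, the residual add sums the mask with itself. Halving
# restores the original mask values. A minimal numpy sketch of that effect
# (illustrative only; not the trax API):

import numpy as np

def _residual_mask_demo():
  mask = np.array([[1.0, 1.0, 0.0]])  # 1 = real token, 0 = padding
  doubled = mask + mask               # what the residual add does to the mask
  restored = doubled / 2.0            # what tl.Div(divisor=2.0) undoes
  assert np.array_equal(restored, mask)
  return restored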
def EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Returns a layer sequence that implements a Transformer encoder block.

  The input to the layer sequence is a pair, (activations, mask), where the
  mask was created from the original source tokens to prevent attending to
  the padding part of the input.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an (activations, mask) pair to an
    (activations, mask) pair.
  """
  attention = [
      tl.LayerNorm(),
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(attention),
      tl.Residual(feed_forward),
  ]
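
# The block above follows the "pre-norm" residual pattern: normalize first,
# then attend, then add the original input back. A minimal numpy sketch of
# that dataflow, with a stand-in `attention` function passed by the caller
# (hypothetical, for illustration; mask handling and dropout omitted):

import numpy as np

def _pre_norm_residual(x, attention, eps=1e-6):
  # LayerNorm over the feature axis.
  mean = x.mean(axis=-1, keepdims=True)
  var = x.var(axis=-1, keepdims=True)
  normed = (x - mean) / np.sqrt(var + eps)
  # Residual connection: input plus the sublayer's output.
  return x + attention(normed)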
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Returns a layer sequence that implements a Transformer decoder block.

  The input to the layer sequence is an activation tensor.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an activation tensor to an activation
    tensor.
  """
  self_attention = [
      tl.LayerNorm(),                           # vec
      tl.Dup(),                                 # vec vec
      tl.Parallel([], tl.CausalMask(axis=-2)),  # vec mask
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # vec
      tl.Dropout(rate=dropout, mode=mode),      # vec
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
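
# CausalMask builds a lower-triangular mask so that position i can attend
# only to positions <= i; that is what lets the decoder train on shifted
# targets without peeking ahead. A numpy sketch of the mask it creates
# (illustrative; not the trax implementation):

import numpy as np

def _causal_mask(seq_len):
  # mask[i, j] is True when position i may attend to position j.
  return np.tril(np.ones((seq_len, seq_len), dtype=bool))

# _causal_mask(3) ->
# [[ True False False]
#  [ True  True False]
#  [ True  True  True]]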
def EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder block.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  attention = [
      tl.LayerNorm(),
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(attention),
      tl.Residual(feed_forward),
  ]
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Select(0),  # Drop the mask.
          tl.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode))
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer decoder block.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  self_attention = [
      tl.LayerNorm(),
      tl.Branch([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return layers.Serial(
      layers.Residual(  # Self-attention block.
          layers.LayerNorm(),
          layers.Branch(),
          layers.Parallel(layers.Identity(),            # activation for (q, k, v)
                          layers.CausalMask(axis=-2)),  # attention mask
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode))
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                        dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (encoder, mask, decoder_input) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
  # Decoder self-attending to decoder.
  self_attention = layers.Residual(
      layers.LayerNorm(),
      layers.Branch(),
      layers.Parallel(layers.Identity(),            # activation for (q, k, v)
                      layers.CausalMask(axis=-2)),  # attention mask
      layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
      layers.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = layers.Serial(
      layers.Reorder(output=((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
      layers.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new v
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      layers.Dropout(rate=dropout, mode=mode),
  )
  return layers.Serial(
      layers.Parallel(layers.Identity(), layers.Identity(), self_attention),
      layers.Branch(),
      layers.Parallel(layers.Identity(), encoder_decoder_attention),
      layers.UnnestBranches(),  # (encoder, mask, old_act, new_act)
      layers.Reorder(output=(0, 1, (2, 3))),
      layers.Parallel(  # Residual after encoder-decoder attention.
          layers.Identity(), layers.Identity(), layers.SumBranches()),
      layers.Parallel(  # Feed-forward on the third component (decoder).
          layers.Identity(), layers.Identity(),
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode)))
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                        dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (encoder, mask, decoder_input) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
  # Decoder self-attending to decoder.
  self_attention = tl.Residual(
      tl.LayerNorm(),
      tl.Branch(tl.NoOp(), tl.CausalMask(axis=-2)),  # create mask
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # drop mask
      tl.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = tl.Serial(
      tl.Select((2, 0, 0, 1)),  # (dec, enc, enc, mask)
      tl.MultiHeadedAttentionQKV(  # (q, k, v, mask) --> new, mask
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Select(0),  # drop the mask
      tl.Dropout(rate=dropout, mode=mode),
  )
  return tl.Serial(
      tl.Parallel(tl.NoOp(), tl.NoOp(), self_attention),
      tl.Branch(tl.NoOp(), encoder_decoder_attention),
      tl.Select(inputs=(('encoder', 'mask', 'old_act'), 'new_act'),
                output=('encoder', 'mask', ('old_act', 'new_act'))),
      tl.Parallel(  # Residual after encoder-decoder attention.
          tl.NoOp(), tl.NoOp(), tl.Add()),
      tl.Parallel(  # Feed-forward on the third component (decoder).
          tl.NoOp(), tl.NoOp(),
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode)))
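
# In the encoder-decoder attention above, queries come from the decoder
# while keys and values both come from the encoder, hence the
# (dec, enc, enc, mask) routing. A single-head numpy sketch of that
# computation (illustrative; MultiHeadedAttentionQKV additionally applies
# learned projections and splits heads):

import numpy as np

def _cross_attention(dec, enc, mask):
  # dec: (t_dec, d); enc: (t_enc, d); mask: (t_enc,) with 1 = real token.
  q, k, v = dec, enc, enc
  scores = q @ k.T / np.sqrt(q.shape[-1])             # (t_dec, t_enc)
  scores = np.where(mask[None, :] > 0, scores, -1e9)  # hide padding
  weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
  weights /= weights.sum(axis=-1, keepdims=True)      # softmax over t_enc
  return weights @ v                                  # (t_dec, d)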
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  decoder_self_attention = [
      # TODO(jonni): Work on combinators so that this flow is cleaner/clearer.
      tl.LayerNorm(),
      tl.Dup(),
      tl.CausalMask(axis=-2),  # Create the self-attention mask.
      tl.Swap(),  # Put mask behind the activations.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Swap(),  # Put self-attention mask on top.
      tl.Drop(),  # Drop self-attention mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  decoder_to_encoder_attention = [
      tl.Select((0, 2, 2, 1, 2)),  # (dec, enc, enc, mask, enc-copy)
      tl.MultiHeadedAttentionQKV(  # (q, k, v, mask, ...) --> (new, mask, ...)
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(decoder_self_attention),
      tl.Residual(decoder_to_encoder_attention),
      tl.Residual(feed_forward),
  ]
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  decoder_self_attention = [                    #         vecs_d  pmask vecs_e
      tl.LayerNorm(),                           #         vecs_d  ..... ......
      tl.Dup(),                                 # vecs_d  vecs_d  ..... ......
      tl.Parallel([], tl.CausalMask(axis=-2)),  # ______  masks   ..... ......
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # ______  0       ..... ......
      tl.Dropout(rate=dropout, mode=mode),      #         vecs_d  ..... ......
  ]
  decoder_to_encoder_attention = [     # vecs_d          masks         vecs_e
      tl.Parallel([], [], tl.Dup()),   # ______          _____  vecs_e vecs_e
      tl.Parallel([], tl.Swap()),      # ______          vecs_e masks  ......
      tl.Parallel([], tl.Dup()),       # ______  vecs_e  vecs_e masks  ......
      tl.MultiHeadedAttentionQKV(      # (q k v masks ... --> vecs_d masks ...)
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  # vecs_d mask vecs_e
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [                                        # vecs_d masks vecs_e
      tl.Residual(decoder_self_attention),        # vecs_d masks vecs_e
      tl.Residual(decoder_to_encoder_attention),  # vecs_d masks vecs_e
      tl.Residual(feed_forward),                  # vecs_d masks vecs_e
  ]
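
# The column comments above trace a data stack: each layer consumes some
# items from the top and pushes its results back. A tiny pure-Python model
# of the stack discipline behind Dup, Swap, Drop, and Select (hypothetical,
# to illustrate the convention only; the real combinators also thread
# parameters and leave deeper stack items in place):

def _dup(stack):           # (a, ...) -> (a, a, ...)
  return (stack[0],) + stack

def _swap(stack):          # (a, b, ...) -> (b, a, ...)
  return (stack[1], stack[0]) + stack[2:]

def _drop(stack):          # (a, ...) -> (...)
  return stack[1:]

def _select(stack, idxs):  # e.g. idxs=(2, 0, 0, 1): copy and reorder items
  return tuple(stack[i] for i in idxs)

# e.g. starting from ('vecs_d', 'masks', 'vecs_e'):
# _dup -> ('vecs_d', 'vecs_d', 'masks', 'vecs_e'), matching the traces above.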
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                        dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  # Decoder self-attending to decoder.
  self_attention = tl.Residual(
      tl.LayerNorm(),
      tl.Dup(),
      tl.CausalMask(axis=-2),  # Create the self-attention mask.
      tl.Swap(),  # Put mask behind the activations.
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Swap(),  # Put self-attention mask on top.
      tl.Drop(),  # Drop self-attention mask.
      tl.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = tl.Serial(
      tl.Select((0, 2, 2, 1, 2)),  # (dec, enc, enc, mask, enc-copy)
      tl.MultiHeadedAttentionQKV(  # (q, k, v, mask, ...) --> (new, mask, ...)
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  )
  return tl.Serial(
      self_attention,
      tl.Residual(encoder_decoder_attention),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode))
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for memory-efficient attention
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  self_attention = [
      tl.LayerNorm(),
      tl.Branch([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  # TODO(kitaev): Memory-efficient attention. This chunking is temporary.
  self_attention = [
      Split(sections=n_attention_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map(self_attention),
      tl.Concatenate(axis=-2),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      ReversibleResidual([self_attention], [feed_forward]),
  ]
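
# ReversibleResidual follows the RevNet construction: activations are split
# into two halves, each residual branch writes into one half, and inputs can
# be recomputed exactly from outputs instead of being stored for backprop,
# which saves activation memory. A numpy sketch of the forward pass and its
# inverse (illustrative; the real layer also threads parameters and RNGs):

import numpy as np

def _reversible_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2

def _reversible_inverse(y1, y2, f, g):
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2

# Round-trip check with arbitrary branch functions:
_f = lambda v: np.tanh(v)
_g = lambda v: 0.5 * v
_x1, _x2 = np.ones(4), np.arange(4.0)
_y1, _y2 = _reversible_forward(_x1, _x2, _f, _g)
assert np.allclose((_x1, _x2), _reversible_inverse(_y1, _y2, _f, _g))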
def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer encoder layer.

  The input to the encoder is a pair (embedded source, mask) where
  the mask is created from the original source to prevent attending
  to the padding part of the input.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a pair (activations, mask).
  """
  # The encoder block expects (activations, mask) as input and returns
  # the new activations only; we add the mask back to the output next.
  encoder_block = layers.Serial(
      layers.Residual(  # Attention block here.
          layers.Parallel(layers.LayerNorm(), layers.Identity()),
          layers.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                      dropout=dropout, mode=mode),
          layers.Dropout(rate=dropout, mode=mode),
          shortcut=layers.FirstBranch()),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode)
  )
  # Now we add the mask back.
  return layers.Serial(
      layers.Reorder(output=((0, 1), 1)),  # (x, mask) --> ((x, mask), mask)
      layers.Parallel(encoder_block, layers.Identity())
  )
def Transformer(source_vocab_size,
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.1,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    shared_embedding: bool: specify whether source/target embeddings are tied
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    A namedtuple containing model 'init' and 'apply' functions for training,
    and an 'evals' function that itself returns a namedtuple containing
    evaluation functions for the trained encoder, decoder, and generator
    substax.
  """
  # Input embedding and positional encoding.
  inject_position = layers.Serial(
      layers.Dropout(dropout, mode=mode),
      layers.PositionalEncoding(feature_depth, max_len=max_len)
  )
  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # Weight-shared embedding.
    embedding = layers.Share(
        layers.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = layers.Serial(embedding, inject_position)
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = layers.Embedding(feature_depth, source_vocab_size)
    target_embedding = layers.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = layers.Serial(source_embedding, inject_position)
    target_embedding_layer = layers.Serial(target_embedding, inject_position)

  # Multi-headed attention and feed-forward layers.
  multi_attention = layers.MultiHeadedAttention(feature_depth,
                                                num_heads=num_heads,
                                                dropout=dropout,
                                                mode=mode)

  # Encoder
  @layers.Lambda
  def Encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: layer variable: raw source sequences
      source_mask: layer variable: self-attention mask

    Returns:
      Layer variable that outputs encoded source.
    """
    encoder_layer = layers.Serial(
        # input attends to self
        layers.Residual(layers.LayerNorm(),
                        layers.Branch(size=4),
                        layers.Parallel(layers.Identity(),  # query
                                        layers.Identity(),  # key
                                        layers.Identity(),  # value
                                        source_mask),       # attention mask
                        multi_attention,
                        layers.Dropout(dropout, mode=mode)),
        # feed-forward
        ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                            mode=mode),
    )
    return layers.Serial(
        source,
        source_embedding_layer,
        layers.repeat(encoder_layer, num_layers),
        layers.LayerNorm(),
    )

  # Decoder
  @layers.Lambda
  def Decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: layer variable: encoded source sequences
      target: layer variable: raw target sequences
      target_mask: layer variable: self-attention mask
      memory_mask: layer variable: memory attention mask

    Returns:
      Layer variable that outputs encoded target.
    """
    decoder_layer = layers.Serial(
        # target attends to self
        layers.Residual(layers.LayerNorm(),
                        layers.Branch(size=4),
                        layers.Parallel(layers.Identity(),  # query
                                        layers.Identity(),  # key
                                        layers.Identity(),  # value
                                        target_mask),       # attention mask
                        multi_attention,
                        layers.Dropout(dropout, mode=mode)),
        # target attends to encoded source
        layers.Residual(layers.LayerNorm(),
                        layers.Branch(size=4),
                        layers.Parallel(layers.Identity(),  # query
                                        memory,             # key
                                        memory,             # value
                                        memory_mask),       # attention mask
                        multi_attention,
                        layers.Dropout(dropout, mode=mode)),
        # feed-forward
        ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                            mode=mode)
    )
    return layers.Serial(
        target,
        target_embedding_layer,
        layers.repeat(decoder_layer, num_layers),
        layers.LayerNorm(),
    )

  # The Transformer
  @layers.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):  # pylint: disable=invalid-name
    encoded_source = Encoder(source, source_mask)
    return Decoder(encoded_source, target, target_mask, memory_mask)

  # Finally, bind the generator transform to use later for inference.
  @layers.Lambda
  def Generator(encoded_target):
    return layers.Serial(
        encoded_target,
        layers.Dense(target_vocab_size),
        layers.LogSoftmax
    )

  # Model-building and evaluation functions.
  # Get the entire model's layer pair.
  top_init, top_apply = Generator(transformer)

  # By default act as a normal constructor and emit an (init, apply) pair.
  if not return_evals:
    return (top_init, top_apply)
  else:
    raise ValueError('inference in this model is still a work in progress')