def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
    """Builds a decoder-only Transformer language model.

    The returned layer applies, in order: ShiftRight, Embedding, Dropout,
    PositionalEncoding, `num_layers` DecoderLayers, LayerNorm,
    Dense(vocab_size) and LogSoftmax.

    Args:
      vocab_size: int: vocab size
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of decoder layers to stack
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    stages = [
        tl.ShiftRight(),
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # One freshly constructed DecoderLayer per requested layer.
    decoder_layers = []
    for _ in range(num_layers):
        decoder_layers.append(
            DecoderLayer(feature_depth, feedforward_depth, num_heads,
                         dropout, mode))
    stages.append(tl.Serial(*decoder_layers))

    # Output head.
    stages.append(tl.LayerNorm())
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    return tl.Serial(*stages)
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
    """Builds a decoder-only Transformer LM that uses a position table.

    The first `max_len` rows of the module-level `_POSITIONS` table are
    handed both to `NewPositionalEncoding` and to every `DecoderLayer`.

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: maximal length; number of leading rows taken from
        `_POSITIONS`
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    positions = _POSITIONS[:max_len, :]

    stages = [
        tl.ShiftRight(),
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        NewPositionalEncoding(positions=positions),
    ]

    # The decoder stack is passed to Serial as a single nested list.
    decoder_stack = []
    for _ in range(n_layers):
        decoder_stack.append(
            DecoderLayer(positions, d_feature, d_feedforward, n_heads,
                         dropout, mode))
    stages.append(decoder_stack)

    stages.append(PreservePosition(tl.LayerNorm()))
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    return tl.Serial(stages)
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_loop_stride=0,
                        mode='train'):
    """Builds a reversible, decoder-only Transformer language model.

    The model concatenates `n_chunks` inputs, embeds them, duplicates the
    activations, runs them through `n_layers` reversible decoder blocks,
    normalizes and re-joins both halves, and finally splits the result back
    into `n_chunks` pieces that each go through Dense + LogSoftmax.

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of *each half* of the two-part features
      d_feedforward: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_loop_stride: int: number of query elements to compute
        attention for in parallel. Set to 0 to disable memory-efficient
        attention.
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # TODO(kitaev): add dropout between the embedding and positional encoding.
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.PositionalEncoding(max_len=max_len),
    ]

    def _decoder_block():
        # All blocks are configured identically.
        return DecoderBlock(d_feature, d_feedforward, d_attention_key,
                            d_attention_value, n_heads, n_attention_chunks,
                            attention_loop_stride, dropout, mode)

    stages = [
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
    ]
    stages.append(
        ReversibleSerial([_decoder_block() for _ in range(n_layers)]))
    stages.append(tl.Parallel(tl.LayerNorm(), tl.LayerNorm()))
    stages.append(tl.Concatenate())
    stages.append(
        Split(sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
    output_head = [tl.Dense(vocab_size), tl.LogSoftmax()]
    stages.append(Map(output_head, sections=n_chunks))
    return tl.Model(*stages)
def TransformerRevnetLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        d_attention_key=64,
                        d_attention_value=64,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        attention_type=DotProductAttention,
                        mode='train'):
    """Builds a reversible, decoder-only Transformer language model.

    Chunked inputs are concatenated, embedded (with broadcasted dropout and
    positional encoding), duplicated, passed through `n_layers` reversible
    decoder blocks, normalized, re-joined, and split back into `n_chunks`
    pieces that each go through Dense + LogSoftmax.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of *each half* of the two-part features
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head
      d_attention_value: int: depth of value vector for each attention head
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for attention
      attention_type: class: attention class to use, such as
        DotProductAttention.
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    positional_embedder = [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.PositionalEncoding(max_len=max_len),
    ]

    def _decoder_block():
        # All blocks are configured identically.
        return DecoderBlock(d_model, d_ff, d_attention_key,
                            d_attention_value, n_heads, n_attention_chunks,
                            attention_type, dropout, mode)

    stages = [
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
    ]
    stages.append(
        tl.ReversibleSerial([_decoder_block() for _ in range(n_layers)]))
    stages.append(tl.Parallel(tl.LayerNorm(), tl.LayerNorm()))
    stages.append(tl.Concatenate())
    stages.append(
        Split(n_sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
    output_head = [tl.Dense(vocab_size), tl.LogSoftmax()]
    stages.append(Map(output_head, n_sections=n_chunks))
    return tl.Model(*stages)
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_kv=False,
                  max_len=2048,
                  mode='train'):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks to stack
      n_heads: int: number of attention heads
      d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads; passed through to DecoderBlock
        unchanged by this function)
      d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads; passed through to DecoderBlock
        unchanged by this function)
      attention_type: subclass of tl.BaseCausalAttention: attention class to
        use
      dropout: float: dropout rate (how much to drop out)
      share_kv: bool, whether to share keys and values in decoder attention
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # tokens -> shifted tokens -> vectors
    stages = [tl.ShiftRight(), embedder]

    # Each block receives its own index within the stack.
    decoder_blocks = []
    for layer_index in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, d_attention_key,
                         d_attention_value, attention_type, dropout,
                         share_kv, layer_index, mode))
    stages.append(decoder_blocks)

    # vectors -> log-probabilities over the vocabulary
    stages.append(tl.LayerNorm())
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    return tl.Model(*stages)
def Transformer(vocab_size,
                d_feature=512,
                d_feedforward=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Builds an encoder-decoder Transformer.

    This model expects on input a pair (source, target). The same
    `positional_embedder` list is used for both the encoder input and the
    (right-shifted) decoder input.

    Args:
      vocab_size: int: vocab size (shared source and target).
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Encoder: embed + build padding mask, run the blocks, then normalize.
    encoder = [tl.Branch(positional_embedder, tl.PaddingMask())]
    encoder_blocks = []
    for _ in range(n_layers):
        encoder_blocks.append(
            EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode))
    encoder.append(encoder_blocks)
    encoder.append(tl.LayerNorm())

    stages = [
        tl.Parallel([], tl.ShiftRight()),
        tl.Parallel(encoder, positional_embedder),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('decoder', ('mask', 'decoder'), 'encoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel([], tl.EncoderDecoderMask(), []),
    ]

    encoder_decoder_blocks = []
    for _ in range(n_layers):
        encoder_decoder_blocks.append(
            EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode))
    stages.append(encoder_decoder_blocks)

    stages.append(tl.Select(0))  # Drop mask and encoder.
    stages.append(tl.LayerNorm())
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    return tl.Model(*stages)
def ChunkedTransformerLM(vocab_size,
                         feature_depth=512,
                         feedforward_depth=2048,
                         num_layers=6,
                         num_heads=8,
                         dropout=0.1,
                         chunk_selector=None,
                         max_len=2048,
                         mode='train'):
    """Builds a Transformer language model that operates on chunks.

    The model input is a sequence given as a list or tuple of chunks:
      (chunk1, chunk2, chunks3, ..., chunkN).
    Every chunk should have the same shape (batch, chunk-length); taken
    together they stand for one long sequence, the concatenation
    chunk1,chunk2,...,chunkN.

    The chunked Transformer emulates a Transformer run on that long
    sequence, except for the chunked attention layer, which may attend to
    only a subset of the chunks in order to reduce memory use.

    Args:
      vocab_size: int: vocab size
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of decoder layers to stack
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      chunk_selector: a function from chunk number to list of chunks to
        attend (if None, attends to the previous chunks which is equivalent
        to setting chunk_selector(x) = [] if x < 1 else [x-1]
        (TransformerXL); we attend to the current chunk with a causal mask
        too, selected chunks unmasked).
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    decoder_layers = []
    for _ in range(num_layers):
        decoder_layers.append(
            ChunkedDecoderLayer(feature_depth, feedforward_depth, num_heads,
                                dropout, chunk_selector, mode))

    # Each Map(L) below applies the layer L to every chunk independently.
    embed_chunks = tl.Map(tl.Embedding(feature_depth, vocab_size))
    dropout_chunks = tl.Map(tl.Dropout(rate=dropout, mode=mode))
    stages = [
        tl.ShiftRight(),
        embed_chunks,
        dropout_chunks,
        tl.PositionalEncoding(max_len=max_len),
        tl.Serial(*decoder_layers),
    ]
    stages.append(tl.Map(tl.LayerNorm()))
    stages.append(tl.Map(tl.Dense(vocab_size)))
    stages.append(tl.Map(tl.LogSoftmax()))
    return tl.Serial(*stages)
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Builds an encoder-decoder Transformer.

    This model expects on input a pair (source, target). One `embedding`
    layer object is used for both the encoder and the decoder side.

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    embedding = layers.Serial(
        layers.Embedding(feature_depth, vocab_size),
        layers.Dropout(rate=dropout, mode=mode),
        layers.PositionalEncoding(max_len=max_len))

    # Encoder pipeline, assembled stage by stage.
    encoder_stages = [
        layers.Branch(),  # Branch input to create embedding and mask.
        layers.Parallel(embedding, layers.PaddingMask()),
    ]
    encoder_layers = []
    for _ in range(num_layers):
        encoder_layers.append(
            EncoderLayer(feature_depth, feedforward_depth, num_heads,
                         dropout, mode))
    encoder_stages.append(layers.Serial(*encoder_layers))
    encoder_stages.append(
        layers.Parallel(layers.LayerNorm(), layers.Identity()))
    encoder = layers.Serial(*encoder_stages)

    stack = []
    for _ in range(num_layers):
        stack.append(
            EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                                dropout, mode))

    stages = [
        layers.Parallel(layers.Identity(), layers.ShiftRight()),
        layers.Parallel(encoder, embedding),
        layers.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
        layers.Reorder(output=(0, (1, 2), 2)),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        layers.Parallel(layers.Identity(),
                        layers.EncoderDecoderMask(),
                        layers.Identity()),
        layers.Serial(*stack),
        layers.ThirdBranch(),
    ]
    stages.append(layers.LayerNorm())
    stages.append(layers.Dense(vocab_size))
    stages.append(layers.LogSoftmax())
    return layers.Serial(*stages)
def TransformerRevnetLM(vocab_size,
                        d_feature=512,
                        d_feedforward=2048,
                        n_layers=6,
                        n_heads=8,
                        dropout=0.1,
                        max_len=2048,
                        n_chunks=32,
                        n_attention_chunks=8,
                        mode='train'):
    """Builds a reversible, decoder-only Transformer language model.

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of *each half* of the two-part features
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out); only forwarded
        to the decoder blocks, the embedder applies no dropout
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      n_attention_chunks: int: number of chunks for memory-efficient
        attention
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # TODO(kitaev): dropout is disabled in the embedder to save memory.
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.PositionalEncoding(max_len=max_len),
    ]

    stages = [
        tl.Concatenate(),
        tl.ShiftRight(),
        positional_embedder,
        Duplicate(),  # pylint: disable=no-value-for-parameter
    ]

    reversible_blocks = []
    for _ in range(n_layers):
        reversible_blocks.append(
            DecoderBlock(d_feature, d_feedforward, n_heads,
                         n_attention_chunks, dropout, mode))
    stages.append(ReversibleSerial(reversible_blocks))

    stages.append(tl.Parallel(tl.LayerNorm(), tl.LayerNorm()))
    stages.append(tl.Concatenate())
    stages.append(
        Split(sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
    output_head = [tl.Dense(vocab_size), tl.LogSoftmax()]
    stages.append(Map(output_head))
    return tl.Model(*stages)
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Builds an encoder-decoder Transformer.

    This model expects on input a pair (source, target). One `embedding`
    layer object is used for both the encoder and the decoder side.

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    embedding = tl.Serial(
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len))

    # Encoder: embed + mask, stacked encoder layers, then normalize the
    # activations while leaving the mask untouched.
    encoder_stages = [tl.Branch(embedding, tl.PaddingMask())]
    encoder_layers = []
    for _ in range(num_layers):
        encoder_layers.append(
            EncoderLayer(feature_depth, feedforward_depth, num_heads,
                         dropout, mode))
    encoder_stages.append(tl.Serial(*encoder_layers))
    encoder_stages.append(tl.Parallel(tl.LayerNorm(), tl.NoOp()))
    encoder = tl.Serial(*encoder_stages)

    stack = []
    for _ in range(num_layers):
        stack.append(
            EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                                dropout, mode))

    stages = [
        tl.Parallel(tl.NoOp(), tl.ShiftRight()),
        tl.Parallel(encoder, embedding),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('encoder', ('mask', 'decoder'), 'decoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel(tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp()),
        tl.Serial(*stack),
        tl.Select(2),  # Drop encoder and mask.
    ]
    stages.append(tl.LayerNorm())
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    return tl.Serial(*stages)
def model(mode):
    """Builds a small two-input, two-output layer; `mode` is ignored.

    The input is a pair (observation stack, action). `n_actions` and
    `obs_shape` are not parameters; they are read from the surrounding
    scope.
    """
    del mode

    input_encoder = layers.Parallel(
        layers.Flatten(),  # Observation stack.
        layers.Embedding(d_feature=1, vocab_size=n_actions),  # Action.
    )
    stages = [
        input_encoder,
        layers.Concatenate(),
        layers.Dense(n_units=1),
        layers.Dup(),
    ]
    output_heads = layers.Parallel(
        layers.Dense(n_units=obs_shape[1]),  # New observation.
        None,  # Reward.
    )
    stages.append(output_heads)
    return layers.Serial(*stages)
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Returns a Transformer encoder model.

    The input to the model is a tensor of tokens.

    Args:
      vocab_size: int: vocab size
      n_classes: how many classes on output
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder blocks to stack
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      A Transformer model as a layer that maps from a tensor of tokens to
      activations over a set of output classes.
    """
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Data stack noted after each stage.
    stages = [
        tl.Dup(),                                  # toks toks
        tl.Parallel(embedder, tl.PaddingMask()),   # vecs mask
    ]

    # Each block receives its own index within the stack.
    encoder_blocks = []
    for layer_index in range(n_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout, layer_index, mode))
    stages.append(encoder_blocks)                  # vecs mask

    stages.append(tl.Parallel([], tl.Drop()))      # vecs (mask dropped)
    stages.append(tl.LayerNorm())                  # vecs
    stages.append(tl.Mean(axis=1))                 # Average on length.
    stages.append(tl.Dense(n_classes))             # vecs
    stages.append(tl.LogSoftmax())                 # vecs
    return tl.Model(stages)
def NeuralGPU(d_feature=96, steps=16, vocab_size=2):
    """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.

    Args:
      d_feature: Number of memory channels (dimensionality of feature
        embedding).
      steps: Number of depthwise recurrence steps; the single
        ConvDiagonalGRU object is repeated this many times.
      vocab_size: Vocabulary size.

    Returns:
      A NeuralGPU Stax model.
    """
    core = ConvDiagonalGRU(units=d_feature)

    embedding = tl.Embedding(d_feature=d_feature, vocab_size=vocab_size)
    # The same `core` object appears at every position of this list.
    recurrence = [core for _ in range(steps)]
    return tl.Model(embedding, recurrence, tl.Dense(vocab_size),
                    tl.LogSoftmax())
def NeuralGPU(feature_depth=96, steps=16, vocab_size=2):
    """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.

    Args:
      feature_depth: Number of memory channels
      steps: Number of depthwise recurrence steps; the single
        ConvDiagonalGRU object is repeated this many times.
      vocab_size: Vocabulary size.

    Returns:
      A NeuralGPU Stax model.
    """
    embedding = tl.Embedding(feature_depth=feature_depth,
                             vocab_size=vocab_size)
    core = ConvDiagonalGRU(units=feature_depth)

    # Embedding, then `steps` references to the same core, then the head.
    stages = [embedding] + [core] * steps
    stages += [tl.Dense(vocab_size), tl.LogSoftmax()]
    return tl.Serial(*stages)
def TransformerLM(vocab_size,
                  d_feature=512,
                  d_feedforward=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks to stack
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # tokens -> shifted tokens -> vectors
    stages = [tl.ShiftRight(), embedder]

    decoder_blocks = []
    for _ in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode))
    stages.append(decoder_blocks)

    # vectors -> log-probabilities over the vocabulary
    stages.append(tl.LayerNorm())
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    return tl.Model(*stages)
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_feature=512,
                       d_feedforward=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Builds the layer list of a Transformer encoder classifier.

    Note that this function returns a plain Python list of layers (with a
    nested list for the encoder blocks), not a combinator object.

    Args:
      vocab_size: int: vocab size
      n_classes: how many classes on output
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of encoder blocks to stack
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer encoder layer.
    """
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Create mask alongside the embedded input.
    stages = [tl.Branch(positional_embedder, tl.PaddingMask())]

    encoder_blocks = []
    for _ in range(n_layers):
        encoder_blocks.append(
            EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode))
    stages.append(encoder_blocks)

    stages.append(tl.Select(0))  # Drop mask.
    stages.append(tl.LayerNorm())
    stages.append(tl.Mean(axis=1))  # Average on length.
    stages.append(tl.Dense(n_classes))
    stages.append(tl.LogSoftmax())
    return stages
def TransformerEncoder(vocab_size,
                       num_classes=10,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_layers=6,
                       num_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Builds a Transformer encoder classifier.

    Args:
      vocab_size: int: vocab size
      num_classes: how many classes on output
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder layers to stack
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer encoder layer.
    """
    input_embedding = layers.Serial(
        layers.Embedding(feature_depth, vocab_size),
        layers.Dropout(rate=dropout, mode=mode),
        layers.PositionalEncoding(max_len=max_len))

    stages = [
        layers.Branch(),  # Branch input to create embedding and mask.
        layers.Parallel(input_embedding, layers.PaddingMask()),
    ]

    encoder_layers = []
    for _ in range(num_layers):
        encoder_layers.append(
            EncoderLayer(feature_depth, feedforward_depth, num_heads,
                         dropout, mode))
    stages.append(layers.Serial(*encoder_layers))

    stages.append(layers.FirstBranch())  # Drop the mask.
    stages.append(layers.LayerNorm())
    stages.append(layers.Mean(axis=1))  # Average on length.
    stages.append(layers.Dense(num_classes))
    stages.append(layers.LogSoftmax())
    return layers.Serial(*stages)
def TransformerLM(vocab_size,
                  d_feature=512,
                  d_feedforward=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
    """Builds a decoder-only Transformer language model.

    Args:
      vocab_size: int: vocab size
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks to stack
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    def _decoder_block():
        # All blocks are configured identically.
        return DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)

    stages = [tl.ShiftRight(), positional_embedder]
    stages.append([_decoder_block() for _ in range(n_layers)])
    stages.extend([tl.LayerNorm(), tl.Dense(vocab_size), tl.LogSoftmax()])
    return tl.Model(*stages)
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.) When `n_chunks` is non-zero,
    the input chunks are concatenated first and the output is split back
    into `n_chunks` sections along axis -2.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks to stack
      n_heads: int: number of attention heads
      d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads; passed through to DecoderBlock
        unchanged by this function)
      d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads; passed through to DecoderBlock
        unchanged by this function)
      attention_type: subclass of tl.BaseCausalAttention: attention class to
        use
      dropout: float: dropout rate (how much to drop out)
      share_qk: bool, whether to share queries and keys in decoder attention
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline)
      mode: str: 'train', 'eval' or 'predict', predict mode is for fast
        inference

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    # Default: no chunking, so both ends of the pipeline are empty stages.
    concatenate_chunks = split_chunks = []
    if n_chunks != 0:
        concatenate_chunks = tl.Concatenate(n_items=n_chunks)
        split_chunks = tl.Split(n_sections=n_chunks, axis=-2)

    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    # tokens (or chunked tuple of tokens) -> tokens -> shifted -> vectors
    stages = [concatenate_chunks, tl.ShiftRight(mode=mode), embedder]

    # Each block receives its own index within the stack.
    decoder_blocks = []
    for layer_index in range(n_layers):
        decoder_blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, d_attention_key,
                         d_attention_value, attention_type, dropout,
                         share_qk, layer_index, mode))
    stages.append(decoder_blocks)

    # vectors -> log-probabilities (optionally re-chunked)
    stages.append(tl.LayerNorm())
    stages.append(tl.Dense(vocab_size))
    stages.append(tl.LogSoftmax())
    stages.append(split_chunks)
    return tl.Model(*stages)
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Returns an encoder-decoder Transformer model.

    The model takes a pair of token tensors: encoder-side (source) tokens
    first, decoder-side (target) tokens second. The pair is swapped
    internally, so the decoder tokens sit on top of the data stack while
    the encoder runs on the source tokens.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab, and the
        very same embedding layers are used on both sides.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      A Transformer model as a layer that maps from a (source, target)
      token pair to activations over a vocab set.
    """
    # tokens -> vecs, for the encoder side.
    in_embed = [
        tl.Embedding(d_model, input_vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # tokens -> vecs, for the decoder side.
    if output_vocab_size is not None:
        out_embed = [
            tl.Embedding(d_model, output_vocab_size),
            tl.Dropout(rate=dropout, mode=mode),
            tl.PositionalEncoding(max_len=max_len),
        ]
    else:
        output_vocab_size = input_vocab_size
        out_embed = in_embed

    # masks vectors --> masks vectors
    encoder_stack = []
    for layer_index in range(n_layers):
        encoder_stack.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout, layer_index, mode))

    # vecs_d masks vecs_e --> vecs_d masks vecs_e
    encoder_decoder_stack = []
    for layer_index in range(n_layers):
        encoder_decoder_stack.append(
            EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_index,
                           mode))

    # Input: encoder_side_tokens, decoder_side_tokens.
    # The trailing comment on each stage shows the data stack after it.
    stages = [tl.Swap()]                             # toks_d toks_e

    # Encode: leave toks_d alone, turn toks_e into (masks, vecs_e).
    encode_branch = [
        tl.Dup(),                                    # toks_e toks_e
        tl.Parallel(in_embed, tl.PaddingMask()),     # vecs_e masks
        encoder_stack,                               # vecs_e masks
        tl.LayerNorm(),                              # vecs_e masks
        tl.Swap(),                                   # masks vecs_e
    ]
    stages.append(tl.Parallel([], encode_branch))    # toks_d masks vecs_e

    # Decode.
    stages.append(tl.ShiftRight())                   # toks_d masks vecs_e
    stages.append(out_embed)                         # vecs_d masks vecs_e
    stages.append(tl.Dup())                          # vecs_d vecs_d masks vecs_e
    stages.append(
        tl.Parallel([], tl.EncoderDecoderMask()))    # vecs_d masks vecs_e
    stages.append(encoder_decoder_stack)             # vecs_d masks vecs_e
    stages.append(
        tl.Parallel([], tl.Drop(), tl.Drop()))       # vecs_d
    stages.append(tl.LayerNorm())                    # vecs_d
    stages.append(tl.Dense(output_vocab_size))       # vecs_d
    stages.append(tl.LogSoftmax())                   # vecs_d
    return tl.Model(*stages)
def Transformer(source_vocab_size,
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.1,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
    """Builds an encoder-decoder Transformer model as an (init, apply) pair.

    The model is assembled from three `layers.Lambda`-decorated builders
    (Encoder, Decoder, Generator) that are composed as
    `Generator(transformer)`.

    Args:
      source_vocab_size: int: source vocab size
      target_vocab_size: int: target vocab size
      mode: str: 'train' or 'eval'
      num_layers: int: number of encoder/decoder layers
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      shared_embedding: bool: specify whether source/target embeddings are
        tied. When true, the two vocab sizes must be equal (enforced with an
        `assert`).
      max_len: int: maximum symbol length for positional encoding
      return_evals: bool: whether to generate decode-time evaluation
        functions. Not supported yet: any truthy value raises ValueError.

    Returns:
      The tuple `(top_init, top_apply)` unpacked from
      `Generator(transformer)`.

    Raises:
      ValueError: if `return_evals` is truthy (inference is still a work in
        progress).
    """
    # Input embedding and positional encoding
    # NOTE: this single `inject_position` object is reused by both the
    # source and the target embedding pipelines below.
    inject_position = layers.Serial(
        layers.Dropout(dropout, mode=mode),
        layers.PositionalEncoding(feature_depth, max_len=max_len))
    if shared_embedding:
        assert source_vocab_size == target_vocab_size
        # Weight-shared Embedding
        embedding = layers.Share(
            layers.Embedding(feature_depth, source_vocab_size))
        source_embedding_layer = layers.Serial(embedding, inject_position)
        # Source and target use the very same embedding pipeline object.
        target_embedding_layer = source_embedding_layer
    else:
        source_embedding = layers.Embedding(feature_depth, source_vocab_size)
        target_embedding = layers.Embedding(feature_depth, target_vocab_size)
        source_embedding_layer = layers.Serial(source_embedding,
                                               inject_position)
        target_embedding_layer = layers.Serial(target_embedding,
                                               inject_position)

    # Multi-headed Attention and Feed-forward layers
    # NOTE(review): this one `multi_attention` object is referenced by every
    # attention sub-layer in both the encoder and the decoder. Whether that
    # ties their weights depends on the layers library -- TODO confirm it is
    # intended.
    multi_attention = layers.MultiHeadedAttention(feature_depth,
                                                  num_heads=num_heads,
                                                  dropout=dropout,
                                                  mode=mode)

    # Encoder
    @layers.Lambda
    def Encoder(source, source_mask):
        """Transformer encoder stack.

        Args:
          source: layer variable: raw source sequences
          source_mask: layer variable: self-attention mask

        Returns:
          Layer variable that outputs encoded source.
        """
        encoder_layer = layers.Serial(
            # input attends to self
            layers.Residual(
                layers.LayerNorm(),
                layers.Branch(size=4),
                layers.Parallel(
                    layers.Identity(),  # query
                    layers.Identity(),  # key
                    layers.Identity(),  # value
                    source_mask),  # attention mask
                multi_attention,
                layers.Dropout(dropout, mode=mode)),
            # feed-forward
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode),
        )
        # Embed the source, apply `encoder_layer` repeated `num_layers`
        # times via layers.repeat, then a final layer normalization.
        return layers.Serial(
            source,
            source_embedding_layer,
            layers.repeat(encoder_layer, num_layers),
            layers.LayerNorm(),
        )

    # Decoder
    @layers.Lambda
    def Decoder(memory, target, target_mask, memory_mask):
        """Transformer decoder stack.

        Args:
          memory: layer variable: encoded source sequences
          target: layer variable: raw target sequences
          target_mask: layer variable: self-attention mask
          memory_mask: layer variable: memory attention mask

        Returns:
          Layer variable that outputs the result of the decoder stack
          applied to the embedded target.
        """
        decoder_layer = layers.Serial(
            # target attends to self
            layers.Residual(
                layers.LayerNorm(),
                layers.Branch(size=4),
                layers.Parallel(
                    layers.Identity(),  # query
                    layers.Identity(),  # key
                    layers.Identity(),  # value
                    target_mask),  # attention mask
                multi_attention,
                layers.Dropout(dropout, mode=mode)),
            # target attends to encoded source
            layers.Residual(
                layers.LayerNorm(),
                layers.Branch(size=4),
                layers.Parallel(
                    layers.Identity(),  # query
                    memory,  # key
                    memory,  # value
                    memory_mask),  # attention mask
                multi_attention,
                layers.Dropout(dropout, mode=mode)),
            # feed-forward
            ResidualFeedForward(feature_depth,
                                feedforward_depth,
                                dropout,
                                mode=mode))
        # Embed the target, apply `decoder_layer` repeated `num_layers`
        # times via layers.repeat, then a final layer normalization.
        return layers.Serial(
            target,
            target_embedding_layer,
            layers.repeat(decoder_layer, num_layers),
            layers.LayerNorm(),
        )

    # The Transformer
    # Encodes the source, then decodes the target against that encoding.
    @layers.Lambda
    def transformer(source, target, source_mask, target_mask, memory_mask):  # pylint: disable=invalid-name
        encoded_source = Encoder(source, source_mask)
        return Decoder(encoded_source, target, target_mask, memory_mask)

    # Finally, bind the generator transform to use later for inference.
    # NOTE(review): `layers.LogSoftmax` is passed here without being called,
    # unlike `layers.LayerNorm()` / `layers.Identity()` above -- confirm
    # whether LogSoftmax is a ready-made layer object or a constructor in
    # this version of the layers library.
    @layers.Lambda
    def Generator(encoded_target):
        return layers.Serial(encoded_target,
                             layers.Dense(target_vocab_size),
                             layers.LogSoftmax)

    # Model-Building and Evaluation Functions

    # Get the entire model's (init, apply) layer pair.
    top_init, top_apply = Generator(transformer)

    # By default act as a normal constructor and emit an (init, apply) pair.
    if not return_evals:
        return (top_init, top_apply)
    else:
        raise ValueError('inference in this model is still a work in progress')