def MultiRNNCell():
    """Assembles the multi-layer RNN cell as a three-stage serial layer.

    Takes no arguments: `n_layers`, `d_model` and `rnn_cell` are read from the
    enclosing scope (they are not defined inside this function).

    Returns:
      A `tl.Serial` made of, in order:
        1. `tl.Parallel([], tl.Split(n_items=n_layers))`,
        2. `tl.SerialWithSideOutputs` over `n_layers` cells, each created as
           `rnn_cell(n_units=d_model)`,
        3. `tl.Parallel([], tl.Concatenate(n_items=n_layers))`.
    """
    # NOTE(review): presumably the `[]` branch of each Parallel passes the
    # first input through untouched while the second input (the recurrent
    # state) is split / re-joined -- confirm against tl.Parallel semantics.
    state_splitter = tl.Parallel([], tl.Split(n_items=n_layers))

    # One independent cell per layer.
    stacked_cells = []
    for _ in range(n_layers):
        stacked_cells.append(rnn_cell(n_units=d_model))
    cell_stack = tl.SerialWithSideOutputs(stacked_cells)

    state_joiner = tl.Parallel([], tl.Concatenate(n_items=n_layers))

    return tl.Serial(state_splitter, cell_stack, state_joiner)
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
    """Builds a decoder-only Transformer language model.

    The model consumes a tensor of tokens (or, when `n_chunks` is non-zero, a
    chunked tuple of token tensors) and is assembled as one `tl.Serial`:
    optional chunk concatenation, right shift, embedding + dropout +
    positional encoding, `n_layers` decoder blocks, layer norm, a dense
    projection to `vocab_size`, log-softmax, and optional chunk splitting.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks to stack
      n_heads: int: number of attention heads
      d_attention_key: int: depth of key vector for each attention head
          (default is d_model // n_heads)
      d_attention_value: int: depth of value vector for each attention head
          (default is d_model // n_heads)
      attention_type: subclass of tl.BaseCausalAttention: attention class to
          use
      dropout: float: dropout rate (how much to drop out)
      share_qk: bool, whether to share queries and keys in decoder attention
      max_len: int: maximum symbol length for positional encoding
      n_chunks: int: number of chunks (must match input pipeline); 0 disables
          chunking
      mode: str: 'train', 'eval' or 'predict', predict mode is for fast
          inference

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set (the last computation before the
      optional split is `tl.LogSoftmax`).
    """
    # Without chunking both chunk stages are the same empty list, which is
    # handed to tl.Serial in place of a layer.
    chunk_merger = chunk_splitter = []
    if n_chunks != 0:
        chunk_merger = tl.Concatenate(n_items=n_chunks)
        # NOTE(review): tl.Split is called with `n_sections`/`axis` here, but
        # with `n_items` in MultiRNNCell in this file -- confirm which keyword
        # the tl.Split in use actually accepts.
        chunk_splitter = tl.Split(n_sections=n_chunks, axis=-2)

    token_embedding = tl.Embedding(d_model, vocab_size)
    embedding_dropout = tl.Dropout(rate=dropout, name='embedding', mode=mode)
    positional_encoding = tl.PositionalEncoding(max_len=max_len, mode=mode)
    embedding_stack = [token_embedding, embedding_dropout, positional_encoding]

    shift_right = tl.ShiftRight(mode=mode)

    # Each block receives its own index as the second-to-last argument.
    decoder_stack = []
    for layer_index in range(n_layers):
        decoder_stack.append(
            DecoderBlock(d_model, d_ff, n_heads, d_attention_key,
                         d_attention_value, attention_type, dropout, share_qk,
                         layer_index, mode))

    final_norm = tl.LayerNorm()
    vocab_projection = tl.Dense(vocab_size)
    log_probs = tl.LogSoftmax()

    return tl.Serial(      # tokens (or chunked tuple of tokens)
        chunk_merger,      # tokens
        shift_right,       # toks
        embedding_stack,   # vecs
        decoder_stack,     # vecs
        final_norm,        # vecs
        vocab_projection,  # vecs
        log_probs,         # vecs
        chunk_splitter,    # vecs (or chunked tuple of vecs)
    )