Example #1
from trax import layers as tl


def MultiRNNCell(n_layers=2, d_model=512, rnn_cell=tl.LSTMCell):
    """Multi-layer RNN cell: maps an (x, state) pair to (y, new_state)."""
    return tl.Serial(
        # Pass x through; split the packed state into one piece per layer.
        tl.Parallel([], tl.Split(n_items=n_layers)),
        # Run the cells in sequence; each consumes and emits its own state
        # piece as a side output.
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        # Pass y through; re-pack the per-layer states into a single tensor.
        tl.Parallel([], tl.Concatenate(n_items=n_layers)))
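On its own, MultiRNNCell only steps a stack of cells once, so a model has to supply an initial state and drive it over the time axis. Below is a minimal sketch of that wiring, loosely following the RNNLM model in trax.models.rnn; the tl.Embedding argument order matches Example #2 below, the factor of 2 in the zero state assumes LSTM cells, and exact layer signatures vary across Trax versions, so treat the details as assumptions rather than the canonical recipe.

def RNNLM(vocab_size, d_model=512, n_layers=2, dropout=0.1, mode='train'):
    """Sketch: RNN language model driving MultiRNNCell over the time axis."""
    return tl.Serial(
        tl.ShiftRight(mode=mode),           # predict token t from tokens < t
        tl.Embedding(d_model, vocab_size),  # token ids -> d_model vectors
        tl.Dropout(rate=dropout, mode=mode),
        # Pair the embeddings with an all-zeros initial state; the factor of
        # 2 is for LSTM cells, which carry both a hidden and a cell state.
        tl.Branch([], tl.MakeZeroState(depth_multiplier=n_layers * 2)),
        tl.Scan(MultiRNNCell(n_layers=n_layers, d_model=d_model), axis=1),
        tl.Select([0], n_in=2),             # keep outputs, drop final state
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )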
Example #2
from trax import layers as tl


def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
    """Returns a Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
    d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys in decoder attention
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
    if n_chunks == 0:
        concatenate_chunks = split_chunks = []
    else:
        concatenate_chunks = tl.Concatenate(n_items=n_chunks)
        split_chunks = tl.Split(n_sections=n_chunks, axis=-2)

    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    return tl.Serial(  # tokens (or chunked tuple of tokens)
        concatenate_chunks,  # tokens
        tl.ShiftRight(mode=mode),  # toks
        embedder,  # vecs
        [
            # DecoderBlock is expected to be defined alongside this model (in
            # Trax it lives in the same module): causal attention followed by
            # a feed-forward layer, one block per decoder layer.
            DecoderBlock(  # pylint: disable=g-complex-comprehension
                d_model, d_ff, n_heads, d_attention_key, d_attention_value,
                attention_type, dropout, share_qk, i, mode)
            for i in range(n_layers)
        ],  # vecs
        tl.LayerNorm(),  # vecs
        tl.Dense(vocab_size),  # vecs
        tl.LogSoftmax(),  # vecs
        split_chunks,  # vecs (or chunked tuple of vecs)
    )
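A quick smoke test of the model above: the hyperparameter values are arbitrary, and the init/signature calls follow the Trax layer API (model.init takes an input signature built with trax.shapes), so adjust for your Trax version.

import numpy as np
from trax import shapes

model = TransformerLM(vocab_size=1000, d_model=128, d_ff=512,
                      n_layers=2, n_heads=4, mode='eval')

tokens = np.zeros((2, 16), dtype=np.int32)  # batch of 2 sequences, length 16
model.init(shapes.signature(tokens))
log_probs = model(tokens)                   # shape (2, 16, 1000): vocab log-probs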