Example #1
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Returns a Transformer encoder model.

  The input to the model is a tensor of tokens.

  Args:
    vocab_size: int: vocab size
    n_classes: how many classes on output
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a tensor of tokens to
    activations over a set of output classes.
  """
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    return tl.Model([  #      tokens
        tl.Dup(),  # toks toks
        tl.Parallel(embedder, tl.PaddingMask()),  # vecs mask
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, mode)
            for _ in range(n_layers)
        ],  # vecs mask
        tl.Parallel([], tl.Drop()),  # ____  0
        tl.LayerNorm(),  # vecs
        tl.Mean(axis=1),  # Average on length.    # vecs
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),  # vecs
    ])
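A minimal usage sketch for the encoder-classifier above (not part of the original source): the vocabulary size and class count are illustrative, and tl refers to the same layers module used in the definition.

model = TransformerEncoder(vocab_size=32000,  # illustrative tokenizer vocabulary
                           n_classes=10,      # e.g. a 10-way document classifier
                           d_model=512,
                           n_layers=6,
                           mode='train')
# The returned layer maps a batch of token ids of shape (batch, length) to
# log-probabilities of shape (batch, n_classes): encoder outputs are averaged
# over the length axis before the final Dense and LogSoftmax.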
Example #2
def NeuralGPU(d_feature=96, steps=16, vocab_size=2):
    """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.

  Args:
    d_feature: Number of memory channels (dimensionality of feature embedding).
    steps: Number of depthwise recurrence steps.
    vocab_size: Vocabulary size.

  Returns:
    A NeuralGPU Stax model.
  """
    core = ConvDiagonalGRU(units=d_feature)
    return tl.Model(
        tl.Embedding(d_feature=d_feature, vocab_size=vocab_size),
        [core] * steps,
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
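For illustration only (not in the original source), the constructor above can be instantiated with its defaults, which target binary (vocab_size=2) algorithmic tasks.

neural_gpu = NeuralGPU(d_feature=96, steps=16, vocab_size=2)
# Embeds the input symbols into d_feature memory channels, applies the
# ConvDiagonalGRU core `steps` times, and projects each position back to
# log-probabilities over the vocabulary.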
def TransformerLM(vocab_size,
                  d_feature=512,
                  d_feedforward=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
    """Returns a Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
    embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    return tl.Model(  # tokens
        tl.ShiftRight(),  # toks
        embedder,  # vecs
        [
            DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
            for _ in range(n_layers)
        ],  # vecs
        tl.LayerNorm(),  # vecs
        tl.Dense(vocab_size),  # vecs
        tl.LogSoftmax(),  # vecs
    )
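A hedged instantiation sketch for the language model above; all hyperparameter values are placeholders.

lm = TransformerLM(vocab_size=8000, d_feature=256, n_layers=4, mode='eval')
# Maps token ids of shape (batch, length) to per-position log-probabilities
# over the vocabulary; ShiftRight keeps the prediction at position t from
# seeing token t, so the model can be trained autoregressively.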
Example #4
def NeuralGPU(feature_depth=96, steps=16, vocab_size=2):
    """Implementation of Neural GPU: https://arxiv.org/abs/1702.08727.

  Args:
    feature_depth: Number of memory channels
    steps: Number of depthwise recurrence steps.
    vocab_size: Vocabulary size.

  Returns:
    A NeuralGPU Stax model.
  """
    xs = []
    xs.append(tl.Embedding(feature_depth=feature_depth, vocab_size=vocab_size))
    core = ConvDiagonalGRU(units=feature_depth)
    xs.extend([core] * steps)
    xs.append(tl.Dense(vocab_size))
    xs.append(tl.LogSoftmax())

    return tl.Serial(*xs)
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_feature=512,
                       d_feedforward=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
    """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    n_classes: how many classes on output
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    return [
        tl.Branch(positional_embedder, tl.PaddingMask()),  # Create mask.
        [
            EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
            for _ in range(n_layers)
        ],
        tl.Select(0),  # Drop mask.
        tl.LayerNorm(),
        tl.Mean(axis=1),  # Average on length.
        tl.Dense(n_classes),
        tl.LogSoftmax(),
    ]
Example #6
def TransformerEncoder(vocab_size,
                       num_classes=10,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_layers=6,
                       num_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    num_classes: how many classes on output
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
  input_embedding = layers.Serial(
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len)
  )
  return layers.Serial(
      layers.Branch(),  # Branch input to create embedding and mask.
      layers.Parallel(input_embedding, layers.PaddingMask()),
      layers.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                                   dropout, mode)
                      for _ in range(num_layers)]),
      layers.FirstBranch(),  # Drop the mask.
      layers.LayerNorm(),
      layers.Mean(axis=1),  # Average on length.
      layers.Dense(num_classes),
      layers.LogSoftmax()
  )
Example #7
def policy_and_value_net(rng_key,
                         batch_observations_shape,
                         num_actions,
                         bottom_layers=None):
  """A policy and value net function."""

  # Layers.
  cur_layers = []
  if bottom_layers is not None:
    cur_layers.extend(bottom_layers)

  # Now, with the current logits, one head computes action probabilities and the
  # other computes the value function.
  # NOTE: We use LogSoftmax instead of Softmax for numerical stability.
  cur_layers.extend([
      layers.Branch(
          layers.Serial(layers.Dense(num_actions), layers.LogSoftmax()),
          layers.Dense(1))
  ])
  net = layers.Serial(*cur_layers)
  return net.initialize(batch_observations_shape, rng_key), net
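A hedged sketch of calling the factory above: the observation shape, action count, and bottom layers are hypothetical, and it is assumed that net.initialize accepts a JAX PRNG key.

import jax

params, net = policy_and_value_net(
    rng_key=jax.random.PRNGKey(0),
    batch_observations_shape=(1, 84, 84, 4),  # hypothetical observation shape
    num_actions=6,                            # hypothetical discrete action space
    bottom_layers=[layers.Dense(64), layers.Relu()])
# `net` shares the bottom layers between two heads: log-probabilities over the
# actions and a scalar value estimate.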
Example #8
def Resnet50(d_hidden=64, n_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    d_hidden: Dimensionality of the first hidden layer (multiplied later).
    n_output_classes: Number of distinct output classes.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a ResNet model with the given parameters.
  """
    del mode
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(d_hidden, (7, 7), (2, 2), 'SAME'),
        tl.BatchNorm(),
        tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [d_hidden, d_hidden, 4 * d_hidden], (1, 1)),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden]),
        IdentityBlock(3, [d_hidden, d_hidden, 4 * d_hidden]),
        ConvBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden], (2, 2)),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden]),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden]),
        IdentityBlock(3, [2 * d_hidden, 2 * d_hidden, 8 * d_hidden]),
        ConvBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden], (2, 2)),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden]),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden]),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden]),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden]),
        IdentityBlock(3, [4 * d_hidden, 4 * d_hidden, 16 * d_hidden]),
        ConvBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden], (2, 2)),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden]),
        IdentityBlock(3, [8 * d_hidden, 8 * d_hidden, 32 * d_hidden]),
        tl.AvgPool(pool_size=(7, 7)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
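A brief, hedged instantiation of the ResNet constructor above; the 1001-class default follows the common ImageNet-plus-background-class convention.

resnet = Resnet50(d_hidden=64, n_output_classes=1001)
# Expects image batches (converted to float inside the model) and returns
# log-probabilities over n_output_classes after the final 7x7 average pool,
# flatten, and dense projection.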
def TransformerLM(vocab_size,
                  d_feature=512,
                  d_feedforward=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
    """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    return tl.Model(
        tl.ShiftRight(),
        positional_embedder,
        [
            DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
            for _ in range(n_layers)
        ],
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
Example #10
def WideResnet(n_blocks=3,
               widen_factor=1,
               n_output_classes=10,
               bn_momentum=0.9,
               mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    n_blocks: int, number of blocks in a group. total layers = 6n + 4.
    widen_factor: int, widening factor of each group. k=1 is vanilla resnet.
    n_output_classes: int, number of distinct output classes.
    bn_momentum: float, momentum in BatchNorm.
    mode: Whether we are training or evaluating or doing inference.

  Returns:
    The list of layers comprising a WideResnet model with the given parameters.
  """
    return tl.Model(
        tl.ToFloat(),
        tl.Conv(16, (3, 3), padding='SAME'),
        WideResnetGroup(n_blocks,
                        16 * widen_factor,
                        bn_momentum=bn_momentum,
                        mode=mode),
        WideResnetGroup(n_blocks,
                        32 * widen_factor, (2, 2),
                        bn_momentum=bn_momentum,
                        mode=mode),
        WideResnetGroup(n_blocks,
                        64 * widen_factor, (2, 2),
                        bn_momentum=bn_momentum,
                        mode=mode),
        tl.BatchNorm(momentum=bn_momentum, mode=mode),
        tl.Relu(),
        tl.AvgPool(pool_size=(8, 8)),
        tl.Flatten(),
        tl.Dense(n_output_classes),
        tl.LogSoftmax(),
    )
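As an illustrative configuration (not from the source), the WRN-28-10 variant from the paper corresponds to n_blocks=4 and widen_factor=10, since the total depth is 6*4 + 4 = 28.

wrn = WideResnet(n_blocks=4, widen_factor=10, n_output_classes=10, mode='train')
# Three groups of 4 blocks at widths 160, 320 and 640 channels, followed by
# batch norm, ReLU, 8x8 average pooling and a 10-way classifier.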
Example #11
def WideResnet(num_blocks=3,
               hidden_size=64,
               num_output_classes=10,
               mode='train'):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.
    mode: is it training or eval.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
    del mode
    return layers.Serial(layers.Conv(hidden_size, (3, 3), padding='SAME'),
                         WideResnetGroup(num_blocks, hidden_size),
                         WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
                         WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
                         layers.BatchNorm(), layers.Relu(),
                         layers.AvgPool(pool_size=(8, 8)), layers.Flatten(),
                         layers.Dense(num_output_classes), layers.LogSoftmax())
Example #12
def PositionLookupTransformerLM(vocab_size=128,
                                d_feature=256,
                                d_feedforward=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
    """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    positions = _POSITIONS[:max_len, :]
    return tl.Serial([
        tl.ShiftRight(),
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        NewPositionalEncoding(positions=positions),
        [
            DecoderLayer(positions, d_feature, d_feedforward, n_heads, dropout,
                         mode) for _ in range(n_layers)
        ],
        PreservePosition(tl.LayerNorm()),
        tl.Dense(vocab_size),
        tl.LogSoftmax()
    ])
Example #13
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
    """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
    del mode
    return tl.Serial(
        tl.Conv(hidden_size, (7, 7), (2, 2),
                'SAME'), tl.BatchNorm(), tl.Relu(),
        tl.MaxPool(pool_size=(3, 3), strides=(2, 2)),
        ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]),
        ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]),
        ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]),
        ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                  (2, 2)),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]),
        tl.AvgPool(pool_size=(7, 7)), tl.Flatten(),
        tl.Dense(num_output_classes), tl.LogSoftmax())
Example #14
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return layers.Serial(
      layers.ShiftRight(),
      layers.Embedding(feature_depth, vocab_size),
      layers.Dropout(rate=dropout, mode=mode),
      layers.PositionalEncoding(max_len=max_len),
      layers.Serial(*[DecoderLayer(feature_depth, feedforward_depth, num_heads,
                                   dropout, mode)
                      for _ in range(num_layers)]),
      layers.LayerNorm(),
      layers.Dense(vocab_size,
                   kernel_initializer=layers.XavierUniformInitializer()),
      layers.LogSoftmax()
  )
Example #15
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  d_attention_key=None,
                  d_attention_value=None,
                  attention_type=tl.DotProductCausalAttention,
                  dropout=0.1,
                  share_qk=False,
                  max_len=2048,
                  n_chunks=0,
                  mode='train'):
  """Returns a Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
    d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys in decoder attention
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  if n_chunks == 0:
    concatenate_chunks = split_chunks = []
  else:
    concatenate_chunks = tl.Concatenate(n_items=n_chunks)
    split_chunks = tl.Split(n_sections=n_chunks, axis=-2)

  embedder = [
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='embedding', mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  return tl.Model(                  # tokens (or chunked tuple of tokens)
      concatenate_chunks,           # tokens
      tl.ShiftRight(mode=mode),     # toks
      embedder,                     # vecs
      [DecoderBlock(  # pylint: disable=g-complex-comprehension
          d_model, d_ff, n_heads, d_attention_key, d_attention_value,
          attention_type, dropout, share_qk, i, mode)
       for i in range(n_layers)],   # vecs
      tl.LayerNorm(),               # vecs
      tl.Dense(vocab_size),         # vecs
      tl.LogSoftmax(),              # vecs
      split_chunks,                 # vecs (or chunked tuple of vecs)
  )
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Model(  # tokens_e tokens_d
        tl.Swap(),  # toks_d toks_e

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Dup(),  # ______ toks_e toks_e
                tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  #        toks_d masks vecs_e
        tl.ShiftRight(),  #        toks_d ..... ......
        out_embed,  #        vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.EncoderDecoderMask()),  # ______    masks     ......
        encoder_decoder_stack,  # vecs_d    masks     vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
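A hedged instantiation of the encoder-decoder model above; the vocabulary size is a placeholder, and leaving output_vocab_size as None assumes source and target share one vocabulary.

translator = Transformer(input_vocab_size=32000, n_layers=6, mode='train')
# With output_vocab_size=None the decoder side reuses the same embedder stack
# as the encoder side (out_embed = in_embed); the model consumes an
# (encoder-side tokens, decoder-side tokens) pair and returns per-position
# log-probabilities over the output vocabulary.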
Example #17
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=32,
               n_attention_chunks=8,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               mode='train'):
    """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    positional_embedder = [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.PositionalEncoding(max_len=max_len),
    ]
    return tl.Model(
        tl.Concatenate(n_items=n_chunks),
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
        tl.ReversibleSerial([
            # pylint: disable=g-complex-comprehension
            DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                         n_heads, n_attention_chunks, attention_type, dropout,
                         share_qk, mode) for _ in range(n_layers)
        ] + [
            SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        ]),
        Map(
            [
                # TODO(kitaev): Test whether dropout should go before or after the
                # LayerNorm, and whether dropout broadcasting is needed here.
                tl.LayerNorm(),
                BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
                tl.Dense(vocab_size),
                tl.LogSoftmax(),
            ],
            n_sections=n_chunks),
    )
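A hedged sketch of calling the Reformer constructor above; the sizes are placeholders, and n_chunks must match how the input pipeline splits each sequence.

reformer_lm = ReformerLM(vocab_size=320, d_model=512, n_layers=6,
                         n_chunks=16, n_attention_chunks=8, mode='train')
# The chunked input is concatenated, shifted right, embedded, run through the
# reversible decoder blocks, then split back into n_chunks chunks of
# per-position log-probabilities over the vocabulary.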