Exemplo n.º 1
0
 def test_numeric_dimensions_pass(self):
     """A spec with literal digit dimensions should accept a (1, 2, 3, 4) input."""
     # Two branches over the same input: dropout and an empty Serial.
     branches = tl.Branch(
         tl.Dropout(rate=0.1),
         tl.Serial(),
     )
     checked = tl.AssertFunction('...34->1234,...34', branches)
     inputs = np.ones((1, 2, 3, 4))
     # The test passes as long as the call does not raise.
     checked(inputs)
Exemplo n.º 2
0
 def test_two_outputs_pass(self):
     """A two-output spec should accept a Branch of Flatten and Dropout."""
     branches = tl.Branch(
         tl.Flatten(n_axes_to_keep=2),
         tl.Dropout(rate=0.1),
     )
     checked = tl.AssertFunction('...cd->...x,...cd', branches)
     inputs = np.ones((1, 2, 3, 4))
     # The test passes as long as the call does not raise.
     checked(inputs)
Exemplo n.º 3
0
 def test_multi_output_rank_fail(self):
     """The '...34->...x,...y' spec must raise LayerError for this Branch."""
     branches = tl.Branch(
         tl.Flatten(n_axes_to_keep=3),
         tl.Serial(),
     )
     checked = tl.AssertFunction('...34->...x,...y', branches)
     inputs = np.ones((1, 2, 3, 4))
     # Callable form of assertRaises: runs checked(inputs) inside the check.
     self.assertRaises(tl.LayerError, checked, inputs)
Exemplo n.º 4
0
 def test_too_many_outputs_fail(self):
     """A spec listing four outputs must raise LayerError for a 3-way Branch."""
     branches = tl.Branch(
         tl.Flatten(n_axes_to_keep=2),
         tl.Dropout(rate=0.1),
         tl.Serial(),
     )
     checked = tl.AssertFunction('...cd->...x,...cd,...cd,...cd', branches)
     inputs = np.ones((1, 2, 3, 4))
     # Callable form of assertRaises: runs checked(inputs) inside the check.
     self.assertRaises(tl.LayerError, checked, inputs)
Exemplo n.º 5
0
 def _inp_layers():
     """Returns the shape-checked input stage.

     Reads `input_vocab_size` and `in_encoder` from an enclosing scope. The
     result is a `tl.AssertFunction` wrapping a `tl.Serial` of a `tl.Select`
     followed by a `tl.Parallel`.
     """
     if input_vocab_size is None:
         # Input in this case is vec_e, mask_e, tok_d. Where all downstream
         # operations expect tok_e, we give it instead mask_e, expecting that
         # downstream ops only are looking for padding/not padding.
         # f: in-feature depth, d: out-vector depth
         spec = 'blf,bl,br->bld,bl,bl,br'
         stage = tl.Serial(  # vec_e mask_e tok_d
             tl.Select([0, 1, 1, 2]),
             tl.Parallel(in_encoder, [], _AsTokenIDs()),
         )  # vec_e mask_e tok_e tok_d
     else:
         # b: batch, l/r: enc/dec length, d: vec depth
         spec = 'bl,br->bld,bl,bl,br'
         stage = tl.Serial(  # tok_e tok_d
             tl.Select([0, 0, 0, 1]),
             tl.Parallel(
                 in_encoder,
                 [tl.PaddingMask(), _RemoveAxes12()],
             ),
         )  # vec_e mask_e tok_e tok_d
     return tl.AssertFunction(spec, stage)
Exemplo n.º 6
0
 def test_reduce_rank_explicit_fail2(self):
     """'abcde->abcd' around Flatten(n_axes_to_keep=3) must raise LayerError."""
     flatten = tl.Flatten(n_axes_to_keep=3)
     checked = tl.AssertFunction('abcde->abcd', flatten)
     inputs = np.ones((1, 2, 3, 4, 5))
     # Callable form of assertRaises: runs checked(inputs) inside the check.
     self.assertRaises(tl.LayerError, checked, inputs)
Exemplo n.º 7
0
 def test_reduce_rank_to_one_pass(self):
     """'abcde->x' should accept Flatten(n_axes_to_keep=0) on a rank-5 input."""
     flatten = tl.Flatten(n_axes_to_keep=0)
     checked = tl.AssertFunction('abcde->x', flatten)
     inputs = np.ones((1, 2, 3, 4, 5))
     # The test passes as long as the call does not raise.
     checked(inputs)
Exemplo n.º 8
0
 def test_reduce_rank_explicit_pass(self):
     """'xyzab->xyzc' should accept Flatten(n_axes_to_keep=3) on a rank-5 input."""
     flatten = tl.Flatten(n_axes_to_keep=3)
     checked = tl.AssertFunction('xyzab->xyzc', flatten)
     inputs = np.ones((1, 2, 3, 4, 5))
     # The test passes as long as the call does not raise.
     checked(inputs)
Exemplo n.º 9
0
 def test_reduce_rank_ellipsis_pass(self):
     """'...ab->...c' should accept Flatten(n_axes_to_keep=3) on a rank-5 input."""
     flatten = tl.Flatten(n_axes_to_keep=3)
     checked = tl.AssertFunction('...ab->...c', flatten)
     inputs = np.ones((1, 2, 3, 4, 5))
     # The test passes as long as the call does not raise.
     checked(inputs)
Exemplo n.º 10
0
 def test_simple_fail(self):
     """'abc->cba' around Dropout must raise LayerError on a (2, 5, 20) input."""
     dropout = tl.Dropout(rate=0.1)
     checked = tl.AssertFunction('abc->cba', dropout)
     inputs = np.ones((2, 5, 20))
     # Callable form of assertRaises: runs checked(inputs) inside the check.
     self.assertRaises(tl.LayerError, checked, inputs)
Exemplo n.º 11
0
 def test_simple_pass(self):
     """'abc->abc' around Dropout should accept a (2, 5, 20) input."""
     dropout = tl.Dropout(rate=0.1)
     checked = tl.AssertFunction('abc->abc', dropout)
     inputs = np.ones((2, 5, 20))
     # The test passes as long as the call does not raise.
     checked(inputs)
Exemplo n.º 12
0
def ConfigurableTransformerEncoder(vocab_size,
                                   n_classes=10,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   max_len=2048,
                                   dropout=0.1,
                                   dropout_shared_axes=None,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   ff_dropout=0.1,
                                   ff_chunk_size=0,
                                   ff_use_sru=0,
                                   ff_sparsity=0,
                                   ff_sparsity_type='1inN',
                                   attention_chunk_size=0,
                                   attention_type=tl.Attention,
                                   pos_type=None,
                                   pos_axial_shape=None,
                                   pos_d_axial_embs=None):
    """Builds a Transformer encoder topped with an N-way classification head.

    The model is meant for text categorization:

      - input: rank 2 tensor of token IDs with shape
        (batch_size, sequence_length); elements are integers in
        `range(vocab_size)` and `0` marks padding positions.

      - output: rank 2 tensor of shape (batch_size, `n_classes`) holding the
        activations of the final `Dense(n_classes)` layer, one vector per
        example. No normalization layer is applied after that `Dense` here.

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
      n_classes: Final dimension of the output tensors, representing N-way
        classification.
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        encoder block.
      n_layers: Number of encoder blocks. Each block includes attention,
        dropout, residual, feed-forward (`Dense`), and activation layers.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      mode: If `'train'`, each encoder block will include dropout; else, it
        will pass all values through unaltered.
      ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
      ff_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
        in addition to the feed-forward block (second int specifies sru size).
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity.
      ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
        ff_sparsity_type=`'1inN'` and use BlockSparseFF if
        ff_sparsity_type=`'Block'`.
      attention_chunk_size: int, if > 0 run attention chunked at this size.
      attention_type: The attention layer to use for the encoder part.
      pos_type: string, the type of positional embeddings to use.
      pos_axial_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model.

    Returns:
      A `tl.Serial` model mapping token-ID batches to per-class activations.
    """
    # Embedding -> dropout -> positional information, wrapped in a shape check
    # that the stage appends one trailing axis to its input.
    embed_and_position = tl.AssertFunction('...->...d', [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        PositionalEncoder(mode, dropout, max_len, pos_type, pos_axial_shape,
                          pos_d_axial_embs),
    ])

    # One identically configured encoder block per layer.
    blocks = []
    for _ in range(n_layers):
        blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation, ff_dropout, ff_chunk_size,
                         ff_use_sru, ff_sparsity, ff_sparsity_type,
                         attention_chunk_size, attention_type))

    # Stack comments on the right show what is on the data stack afterwards.
    model = tl.Serial(  # toks
        # Encode.
        tl.Branch(embed_and_position, tl.PaddingMask()),  # vecs masks
        blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Map to output categories.
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # vecs
    )
    return model
Exemplo n.º 13
0
def EmbeddingAndPositionalEncodings(input_vocab_size,
                                    d_model,
                                    mode,
                                    embedding_dropout,
                                    dropout_shared_axes,
                                    max_len,
                                    output_vocab_size=None,
                                    pos_type=None,
                                    pos_axial_shape=None,
                                    pos_d_axial_embs=None,
                                    pos_start_from_zero_prob=1.0,
                                    pos_max_offset_to_add=0,
                                    use_bfloat16=False):
    """Builds the input and output embedding + positional-encoding stages.

    Args:
      input_vocab_size: Input vocabulary size -- each element of the input
        tensor should be an integer in `range(vocab_size)`. These integers
        typically represent token IDs from a vocabulary-based tokenizer. If
        `None`, a `Dense(d_model)` layer is used in place of an `Embedding`
        for the input side.
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      mode: If `'predict'`, use fast inference. If `'train'`, each
        encoder/decoder block will include dropout; else, it will pass all
        values through unaltered.
      embedding_dropout: Stochastic rate (probability) for dropping an
        activation value when applying dropout after the embedding block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      max_len: Maximum symbol length for positional encoding.
      output_vocab_size: If specified, gives the vocabulary size for the
        targets; if None, then input and target integers (token IDs) are
        assumed to come from the same vocabulary.
      pos_type: string, the type of positional embeddings to use.
      pos_axial_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model.
      pos_start_from_zero_prob: how often to start from 0 during training,
        (if 1.0, we always start from position 0, if less, we randomize).
      pos_max_offset_to_add: maximum offset to add to positions during
        training when randomizing; this offset plus input length must still
        be less than max_len for all training examples.
      use_bfloat16: If `True`, use bfloat16 weights instead of the default
        float32; this can save memory but may (rarely) lead to numerical
        issues.

    Returns:
      A tuple of (input encoder, output encoder, output vocab size used).
    """
    # Settings shared by the encoder-side and decoder-side positional encoders.
    positional_kwargs = dict(
        dropout=embedding_dropout,
        max_len=max_len,
        pos_type=pos_type,
        pos_axial_shape=pos_axial_shape,
        pos_d_axial_embs=pos_d_axial_embs,
        pos_start_from_zero_prob=pos_start_from_zero_prob,
        pos_max_offset_to_add=pos_max_offset_to_add,
        use_bfloat16=use_bfloat16,
    )

    def _embed_then_dropout(n_tokens, dropout_mode):
        # tokens --> vectors; falls back to a Dense projection when there is
        # no vocabulary size to build an Embedding from.
        to_vectors = (
            tl.Embedding(n_tokens, d_model, use_bfloat16=use_bfloat16)
            if n_tokens is not None
            else tl.Dense(d_model, use_bfloat16=use_bfloat16))
        drop = tl.Dropout(rate=embedding_dropout,
                          shared_axes=dropout_shared_axes,
                          mode=dropout_mode)
        return [to_vectors, drop]

    # NOTE: Positional encodings are not shared between encoder and decoder.

    # Since encoder doesn't run stepwise, we do not use predict mode there.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_embedder = _embed_then_dropout(input_vocab_size, encoder_mode)
    in_encoder = in_embedder + [
        PositionalEncoder(encoder_mode, **positional_kwargs)
    ]

    # At least one vocabulary size has to be provided.
    assert input_vocab_size or output_vocab_size
    if output_vocab_size is None:
        # Reuse the very same embedding layers for the targets, and report
        # the vocabulary size that is actually in use.
        out_embedder = in_embedder
        output_vocab_size = input_vocab_size
    else:
        out_embedder = _embed_then_dropout(output_vocab_size, mode)

    out_encoder = out_embedder + [
        PositionalEncoder(mode, **positional_kwargs)
    ]

    # The input-side shape spec depends on whether the input is token IDs or
    # already-embedded vectors; the output side is always checked the same way.
    in_spec = '...a->...b' if input_vocab_size is None else '...->...d'
    in_encoder = tl.AssertFunction(in_spec, in_encoder)
    out_encoder = tl.AssertFunction('...->...d', out_encoder)

    return in_encoder, out_encoder, output_vocab_size
Exemplo n.º 14
0
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              forget_dense=True,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
    """Reversible transformer encoder-decoder model.

    If input_vocab_size is not None, this model expects an input pair: source,
    target. Otherwise, it expects a triple: embedded_source, mask, target.

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab.
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head;
        defaults to `d_model // n_heads` when None.
      d_attention_value: int: depth of value vector for each attention head;
        defaults to `d_model // n_heads` when None.
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      encoder_attention_type: class: attention class to use, such as
        SelfAttention
      encoder_decoder_attention_type: class, or tuple/list of classes:
        attention class to use, such as SelfAttention. When a tuple or list
        is given, its length must divide n_decoder_layers and the types are
        cycled through the decoder layers.
      pos_type: string, the type of positional embeddings to use.
      pos_axial_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model.
      pos_start_from_zero_prob: how often to start from 0 during training,
        (if 1.0, we always start from position 0, if less, we randomize).
      pos_max_offset_to_add: maximum offset to add to positions during
        training when randomizing; this offset plus input length must still
        be less than max_len for all training examples.
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity
      loss_sparsity_type: str, type of sparsity to used in loss layer. See
        SparseDenseWithOptions for options. None if no sparsity should be
        used.
      loss_sparsity: int, the sparsity for loss layer (if used)
      loss_d_lowrank: int, the dimensions for intermediate layer (if used)
      loss_sparsity_prob: float, the probability for sparse version of loss
        to be used. If None, only sparse version is used.
      attention_chunk_size: int, if > 0 run attention chunked at this size
      n_layers_forget: how often to have a forgetting block between layers
      forget_dense: whether to use Dense or no-op (Serial) as a forget layer.
      n_decoder_attention_layers: how many attention layers in a decoder block
      use_bfloat16: whether to use bfloat16 for weights (default: False)
      reversible_encoder: whether to be reversible through the encoder
      use_two_swaps_per_encoder_block: whether to allow even number of swaps
        in the encoder
      center_layernorm: whether to use centering in LayerNorm (default) or if
        to skip it, which is known as RMS normalization.
      half_before_layer: int, half d_model and d_ff before that layer
      double_after_layer: int, double d_model and d_ff after that layer
      mode: str: 'train', 'eval' or 'predict'

    Returns:
      A Reformer model as a layer that maps from a source, target pair to
      activations over a vocab set.

    Raises:
      ValueError: if n_heads does not divide d_model/2.
    """
    # Set default dimensions for attention head key and value sizes.
    if (d_model / 2) % n_heads != 0:
        raise ValueError(
            f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
    if d_attention_key is None:
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        d_attention_value = d_model // n_heads

    # Set values of d_model, d_ff and d_qkv for the first stage.
    d_model1, d_ff1 = d_model, d_ff
    d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
    if half_before_layer:
        # Integer division: these values are used as layer dimensions, so they
        # must stay ints (true division would silently turn them into floats).
        d_model1, d_ff1 = d_model // 2, d_ff // 2
        d_attention_key1 = d_attention_key // 2
        d_attention_value1 = d_attention_value // 2

    # Set values of d_model, d_ff and d_qkv for the final stage.
    d_model2, d_ff2 = d_model, d_ff
    d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
    if double_after_layer:
        d_model2, d_ff2 = d_model * 2, d_ff * 2
        d_attention_key2 = d_attention_key * 2
        d_attention_value2 = d_attention_value * 2

    # Vector embeddings.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model1,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            pos_type=pos_type,
            pos_axial_shape=pos_axial_shape,
            pos_d_axial_embs=pos_d_axial_embs,
            pos_start_from_zero_prob=pos_start_from_zero_prob,
            pos_max_offset_to_add=pos_max_offset_to_add,
            use_bfloat16=use_bfloat16))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model1,
                     d_ff1,
                     n_heads,
                     encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     attention_chunk_size=attention_chunk_size,
                     center_layernorm=center_layernorm,
                     use_bfloat16=use_bfloat16,
                     use_two_swaps_per_block=use_two_swaps_per_encoder_block,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = [  # vec_e mask_e tok_e tok_d tok_d
        tl.ReversibleSelect([0, 0]),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget,
                                forget_dense)
    ]
    if not reversible_encoder:
        # Merge the two reversible streams back into a single vector stream.
        encoder += [
            tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
            tl.Dense(d_model1, use_bfloat16=use_bfloat16),
            tl.LayerNorm(),
        ]
    encoder = tl.Serial(encoder)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    # Normalize the attention type to a list so it can be cycled over layers.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        # Grow d_model, d_ff, and d_qkv if requested.
        d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
        if half_before_layer and layer_idx >= half_before_layer:
            d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
        if double_after_layer and layer_idx > double_after_layer:
            d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
        decoder_block = DecoderBlock(
            d_m,
            d_f,
            d_k,
            d_v,
            n_heads,
            attention_type=layer_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            n_attention_layers=n_decoder_attention_layers,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            mode=mode)
        decoder_blocks.append(decoder_block)
        # A ReversibleConcatenatePair is inserted at each size-change boundary.
        if half_before_layer and layer_idx == half_before_layer - 1:
            decoder_blocks.append(tl.ReversibleConcatenatePair())
        if double_after_layer and layer_idx == double_after_layer:
            decoder_blocks.append(tl.ReversibleConcatenatePair())

    dense_loss_layer = tl.SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model2,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

    # Layers to merge encoder and decoder, see below for details.
    if reversible_encoder:
        encdec_layers = [
            tl.ReversibleSelect([0, 1, 4, 2,
                                 3]),  # vec_e vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding2(mode=mode),  # vec_ed vec_ed tok_e tok_d
        ]
    else:
        encdec_layers = [
            tl.ReversibleSelect([0, 3, 1,
                                 2]),  # vec_e vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d
            tl.ReversibleSelect([0, 0]),  # vec_ed vec_ed tok_e tok_d
        ]

    if input_vocab_size is not None:
        # Input in this case is tok_e, tok_d.
        mask_layers = [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]
        inp_layers = tl.Serial([
            tl.Select([0, 0, 0, 1]),  # tok_e tok_e tok_e tok_d
            tl.Parallel(in_encoder, mask_layers)  # vec_e mask_e tok_e tok_d
        ])
        inp_layers = tl.AssertFunction('bt,bu->btf,bt,bt,bu', inp_layers)
    else:
        # Input in this case is vec_e, mask_e, tok_d. Where all downstream
        # operations expect tok_e, we give it instead mask_e, expecting that
        # downstream ops only are looking for padding/not padding.
        make_tok = tl.Fn('MakeTok', lambda mask: mask.astype(jnp.int32))
        inp_layers = tl.Serial([
            tl.Select([0, 1, 1, 2]),  # vec_e mask_e tok_e tok_d
            tl.Parallel(in_encoder, [], make_tok)  # vec_e mask_e tok_e tok_d
        ])
        inp_layers = tl.AssertFunction('btg,bt,bu->btf,bt,bt,bu', inp_layers)

    # Assemble and return the model.
    return tl.Serial(
        inp_layers,  # vec_e mask_e tok_e tok_d

        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 2, 3, 3]),  # vec_e mask_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            [],
            [],
            [],  # vec_e mask_e tok_e vec_d tok_d
            [tl.ShiftRight(mode=mode), out_encoder]),

        # Predict mode doesn't work with padding in encoder. Raising an exception
        # in jitted function isn't possible, so the second next best thing is
        # to convert every embedding to NaNs, so the user will not get subtly
        # wrong results, but clearly wrong results.
        (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),

        # Encode.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Concat encoder and decoder, given encoder mask.
        encdec_layers,

        # Run decoder blocks.
        _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                                forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        dense_loss_layer,  # vec_d tok_d
    )