def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            pos_type=None,
                            pos_axial_shape=None,
                            pos_d_axial_embs=None,
                            enc_dec_attention_sparsity=0):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
      if None, then input and target integers (token IDs) are assumed to come
      from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each
      encoder/decoder block will include dropout; else, it will pass all values
      through unaltered.
    ff_activation: Type of activation function at the end of each
      encoder/decoder block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
      in addition to the feed-forward block (second int specifies SRU size).
    ff_sparsity: int; if > 0, use a sparse feed-forward block with this
      sparsity.
    ff_sparsity_type: string; if ff_sparsity > 0, use SparseFF if
      ff_sparsity_type=`'1inN'` and BlockSparseFF if ff_sparsity_type=`'Block'`.
    loss_sparsity_type: str, type of sparsity to be used in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for the loss layer (if used).
    loss_d_lowrank: int, the dimension of the intermediate layer (if used).
    loss_sparsity_prob: float, the probability for the sparse version of the
      loss to be used. If None, only the sparse version is used.
    attention_chunk_size: int; if > 0, run attention chunked at this size.
    encoder_attention_type: The attention layer to use for the encoder part.
    encoder_decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    enc_dec_attention_sparsity: int; if > 0, use this sparsity in attention.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  in_encoder, out_encoder, output_vocab_size = (
      EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          dropout_shared_axes,
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                   mode, ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
                   ff_sparsity, ff_sparsity_type,
                   attention_chunk_size, encoder_attention_type)
      for i in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # pylint: disable=g-complex-comprehension
  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                          mode, ff_activation, ff_dropout, ff_chunk_size,
                          ff_use_sru, ff_sparsity, ff_sparsity_type,
                          attention_chunk_size, encoder_decoder_attention_type,
                          enc_dec_attention_sparsity)
      for i in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
      encoder,                                 # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),                # tok_d ..... ..... .....
      out_encoder,                             # vec_d ..... ..... .....
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      encoder_decoder_blocks,                  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                  # vec_d tok_d
      tl.SparseDenseWithOptions(               # vec_d .....
          output_vocab_size,
          d_input=d_model,
          sparsity_type=loss_sparsity_type,
          sparsity=loss_sparsity,
          d_lowrank=loss_d_lowrank,
          prob_sparse=loss_sparsity_prob,
          mode=mode),
  )

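# --- Illustrative usage sketch (not part of the original module) -----------
# The helper below shows how ConfigurableTransformer might be instantiated
# with modest dimensions. The vocabulary size and layer counts are placeholder
# values chosen for illustration; the sparsity, SRU and chunking knobs are
# left at their defaults, which reproduces a standard dense Transformer.
def _example_configurable_transformer():
  """Builds a tiny ConfigurableTransformer for illustration purposes."""
  return ConfigurableTransformer(
      input_vocab_size=256,   # placeholder vocab size
      d_model=64,
      d_ff=128,
      n_encoder_layers=2,
      n_decoder_layers=2,
      n_heads=2,
      max_len=128,
      mode='eval')
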
def Transformer(input_vocab_size, output_vocab_size=None, d_model=512,
                d_ff=2048, n_encoder_layers=6, n_decoder_layers=6, n_heads=8,
                dropout=0.1, max_len=2048, mode='train'):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  in_embed = [                                    # tokens
      tl.Embedding(d_model, input_vocab_size),    # vecs
      tl.Dropout(rate=dropout, mode=mode),        # vecs
      tl.PositionalEncoding(max_len=max_len),     # vecs
  ]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
    out_embed = in_embed
  else:
    out_embed = [                                    # tokens
        tl.Embedding(d_model, output_vocab_size),    # vecs
        tl.Dropout(rate=dropout, mode=mode),         # vecs
        tl.PositionalEncoding(max_len=max_len),      # vecs
    ]

  encoder_stack = (  # masks vectors --> masks vectors
      [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_encoder_layers)])

  encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
      [EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_decoder_layers)])

  # Input: encoder_side_tokens, decoder_side_tokens
  return tl.Serial(                 # tokens_e tokens_d
      tl.Parallel([], tl.Dup()),    # toks_e toks_d toks_d (for loss)
      tl.Swap(),                    # toks_d toks_e ....

      # Encode.
      tl.Parallel(                  # toks_d        toks_e
          [],
          [tl.Dup(),                                 # ______ toks_e toks_e
           tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
           encoder_stack,                            # ______ vecs_e masks
           tl.LayerNorm(),                           # ______ vecs_e .....
           tl.Swap()]),                              # ______ masks  vecs_e

      # Decode.                                 # toks_d masks vecs_e
      tl.ShiftRight(),                          # toks_d ..... ......
      out_embed,                                # vecs_d ..... ......
      tl.Dup(),                                 # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.EncoderDecoderMask()), # ______ masks ......
      encoder_decoder_stack,                    # vecs_d masks vecs_e
      tl.Parallel([], tl.Drop(), tl.Drop()),    # vecs_d
      tl.LayerNorm(),                           # vecs_d
      tl.Dense(output_vocab_size),              # vecs_d
      tl.LogSoftmax(),                          # vecs_d
  )

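# --- Illustrative usage sketch (not part of the original module) -----------
# This older formulation takes (encoder_side_tokens, decoder_side_tokens),
# duplicates the decoder tokens for use in the loss, and emits log
# probabilities. The sizes below are placeholders. Note that if several
# `Transformer` definitions coexist in one module (as in this file), the last
# definition shadows the earlier ones at call time.
def _example_legacy_transformer():
  """Builds a tiny instance of the Parallel/Swap-based Transformer above."""
  return Transformer(input_vocab_size=256, d_model=64, d_ff=128,
                     n_encoder_layers=2, n_decoder_layers=2, n_heads=2,
                     dropout=0.1, max_len=64, mode='train')
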
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=D_MODEL,
                d_ff=D_FF,
                n_encoder_layers=N_LAYERS,
                n_decoder_layers=N_LAYERS,
                n_heads=N_HEADS,
                max_len=MAX_SEQUENCE_LENGTH,
                dropout=DROPOUT_RATE,
                dropout_shared_axes=DROPOUT_SHARED_AXES,
                mode=MODE,
                ff_activation=FF_ACTIVATION_TYPE):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: Array representing a batch of text strings via token IDs
          plus padding markers; shape is (batch_size, sequence_length), where
          sequence_length <= ``max_len``. Array elements are integers in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: Array representing a batch of text strings via token IDs
          plus padding markers; shape is (batch_size, sequence_length), where
          sequence_length <= ``max_len``. Array elements are integers in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 3-D array of raw activations with last/innermost dimension of
      ``output_vocab_size``, suitable for decoding into a batch of token
      strings; shape is (batch_size, sequence_length, ``vocab_size``).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in ``range(vocab_size)``. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
      if ``None``, then input and target integers (token IDs) are assumed to
      come from the same vocabulary.
    d_model: Last/innermost dimension of activation arrays at most points in
      the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
      :py:class:`Dense` layer in the feedforward part of each encoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within encoder/decoder blocks. The same rate is
      also used for attention dropout in encoder/decoder blocks.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If ``'predict'``, use fast inference. If ``'train'``, each
      encoder/decoder block will include dropout; else, it will pass all
      values through unaltered.
    ff_activation: Type of activation function at the end of each
      encoder/decoder block; must be an activation-type subclass of
      :py:class:`Layer`.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise.
  encoder_mode = 'eval' if mode == 'predict' else mode

  # Share embedding weights if no separate output vocab size.
  in_embedder = tl.Embedding(input_vocab_size, d_model)
  if output_vocab_size is None:
    out_embedder = in_embedder
    output_vocab_size = input_vocab_size
  else:
    out_embedder = tl.Embedding(output_vocab_size, d_model)

  def _Dropout():
    return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  def _EncBlock():
    return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation)

  def _Encoder():
    encoder = tl.Serial(
        in_embedder,
        _Dropout(),
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
        [_EncBlock() for _ in range(n_encoder_layers)],
        tl.LayerNorm(),
    )
    return tl.Cache(encoder) if mode == 'predict' else encoder

  def _EncDecBlock():
    return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                dropout_shared_axes, mode, ff_activation)

  # Input to model is encoder-side tokens and decoder-side tokens: tok_e, tok_d
  # Model output is decoder-side vectors and decoder-side tokens: vec_d, tok_d
  return tl.Serial(
      tl.Select([0, 1, 1]),  # Copies decoder tokens for use in loss.

      # Encode.
      tl.Branch([], tl.PaddingMask()),  # tok_e masks tok_d tok_d
      _Encoder(),

      # Decode.
      tl.Select([2, 1, 0]),  # Re-orders inputs: tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),
      out_embedder,
      _Dropout(),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      [_EncDecBlock() for _ in range(n_decoder_layers)],
      tl.LayerNorm(),
      tl.Select([0], n_in=3),  # Drops masks and encoding vectors.

      # Map vectors to match output vocab size.
      tl.Dense(output_vocab_size),
  )

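# --- Illustrative usage sketch (not part of the original module) -----------
# A minimal end-to-end check of the Transformer above: build a tiny model,
# initialize it from an input signature, and run one forward pass. The
# imports, vocabulary size, and shapes below are illustrative assumptions,
# not values taken from this module.
def _example_transformer_forward():
  """Runs a toy forward pass; returns the shape of the output logits."""
  import numpy as np              # assumed available in the environment
  from trax import shapes         # assumed trax utility for input signatures

  model = Transformer(input_vocab_size=128, d_model=32, d_ff=64,
                      n_encoder_layers=2, n_decoder_layers=2, n_heads=2,
                      max_len=16, mode='eval')
  source = np.ones((1, 8), dtype=np.int32)   # batch of 1, sequence length 8
  target = np.ones((1, 8), dtype=np.int32)
  model.init(shapes.signature((source, target)))
  logits, _ = model((source, target))        # second output is the tok_d copy
  return logits.shape                        # expected: (1, 8, 128)
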
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  def PositionalEncoder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
      for i in range(n_encoder_layers)]

  encoder = tl.Serial(
      in_encoder,
      encoder_blocks,
      tl.LayerNorm()
  )
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                           ff_activation)
      for i in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
      encoder,                                 # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(),                         # tok_d ..... ..... .....
      out_encoder,                             # vec_d ..... ..... .....
      tl.Branch(
          [], tl.EncoderDecoderMask()),        # vec_d masks ..... .....
      encoder_decoder_blocks,                  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                  # vec_d tok_d
      tl.Dense(output_vocab_size),             # vec_d .....
      tl.LogSoftmax(),                         # vec_d .....
  )

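# --- Illustrative usage sketch (not part of the original module) -----------
# This variant exposes the feed-forward non-linearity, so e.g. a GELU-based
# model could be configured as below. The dimensions are placeholders chosen
# only for illustration.
def _example_transformer_with_gelu():
  """Builds a tiny Transformer using tl.Gelu as the feed-forward activation."""
  return Transformer(input_vocab_size=512, d_model=64, d_ff=128,
                     n_encoder_layers=2, n_decoder_layers=2, n_heads=2,
                     max_len=64, mode='train', ff_activation=tl.Gelu)
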
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                max_len=2048,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.Relu):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
      if None, then input and target integers (token IDs) are assumed to come
      from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each
      encoder/decoder block will include dropout; else, it will pass all values
      through unaltered.
    ff_activation: Type of activation function at the end of each
      encoder/decoder block; must be an activation-type subclass of `Layer`.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  def Embedder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

  in_embedder = Embedder(input_vocab_size)
  out_embedder = (in_embedder if output_vocab_size is None
                  else Embedder(output_vocab_size))

  # Positional encodings are not shared between encoder and decoder.
  # Since encoder doesn't run stepwise, we do not use predict mode there.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = in_embedder + [
      tl.PositionalEncoding(max_len=max_len, mode=encoder_mode)
  ]
  out_encoder = out_embedder + [
      tl.PositionalEncoding(max_len=max_len, mode=mode)
  ]
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                    mode, ff_activation)
      for i in range(n_encoder_layers)
  ]

  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                           mode, ff_activation)
      for i in range(n_decoder_layers)
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
      encoder,                                 # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),                # tok_d ..... ..... .....
      out_encoder,                             # vec_d ..... ..... .....
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      encoder_decoder_blocks,                  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                  # vec_d tok_d
      tl.Dense(output_vocab_size),             # vec_d .....
  )

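# --- Illustrative usage sketch (not part of the original module) -----------
# When output_vocab_size is None, the variant above reuses the same embedding
# stack (and vocabulary) on the decoder side; only the positional encodings
# differ, and in 'predict' mode the encoder is wrapped in tl.Cache so it runs
# once per input rather than once per decoding step. Placeholder sizes below.
def _example_shared_vocab_predict_model():
  """Builds a shared-vocabulary Transformer configured for fast decoding."""
  return Transformer(input_vocab_size=1024,   # shared source/target vocab
                     d_model=64, d_ff=128,
                     n_encoder_layers=2, n_decoder_layers=2, n_heads=2,
                     max_len=64, mode='predict')
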
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long
    # sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_encoder_layers)
  ]

  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_decoder_layers)
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                     # tok_e tok_d tok_d

      # Encode.
      tl.Branch(in_encoder, tl.PaddingMask()),  # vec_e masks tok_d .....
      tl.Dup(),                                 # vec_e1 vec_e2 masks tok_d ...
      tl.ReversibleSerial(encoder_blocks),      # vec_e1 vec_e2 masks tok_d ...
      # The two sets of activations need to be reduced to one, in this case by
      # averaging them. Note that ReformerLM concatenates instead. Various
      # options (concat, average, add, keep only one, etc.) seem to perform
      # similarly. We don't concatenate here because we want exact parameter
      # parity with the standard Transformer.
      tl.Fn(lambda x, y: (x + y) / 2.0),        # vec_e masks tok_d .....
      tl.LayerNorm(),                           # vec_e masks tok_d .....

      # Decode.
      tl.Select([2, 1, 0]),                     # tok_d masks vec_e .....
      tl.ShiftRight(),                          # tok_d masks vec_e .....
      out_encoder,                              # vec_d masks vec_e .....
      tl.Branch([], tl.EncoderDecoderMask()),   # vec_d masks vec_e .....
      tl.Dup(),                                 # vec_d1 vec_d2 masks vec_e ...
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x + y) / 2.0),        # vec_d masks vec_e .....
      tl.LayerNorm(),                           # vec_d masks vec_e .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                   # vec_d .....
      tl.Dense(output_vocab_size),              # vec_d .....
      tl.LogSoftmax(),                          # vec_d .....
  )

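# --- Illustrative usage sketch (not part of the original module) -----------
# The reversible encoder-decoder is constructed like the standard Transformer;
# reversibility only changes how activations are handled during backprop, not
# the model's interface. Note that building the model patches jax.api as a
# workaround (see the comment inside Reformer), so this sketch assumes the jax
# version this file was written against. Sizes below are placeholders.
def _example_reformer():
  """Builds a tiny reversible encoder-decoder for illustration purposes."""
  return Reformer(input_vocab_size=256, d_model=64, d_ff=128,
                  n_encoder_layers=2, n_decoder_layers=2, n_heads=2,
                  max_len=64, mode='train')
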