def ConfigurableTerraformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            d_attention_key=None,
                            d_attention_value=None,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            dropout=0.1,
                            max_len=2048,
                            encoder_attention_type=tl.SelfAttention,
                            encoder_decoder_attention_type=tl.SelfAttention,
                            pos_type='fixed-base',
                            pos_axial_shape=(),
                            pos_d_axial_embs=None,
                            pos_start_from_zero_prob=1.0,
                            pos_max_offset_to_add=0,
                            ff_activation=tl.Relu,
                            ff_use_sru=0,
                            ff_chunk_size=0,
                            ff_dropout=None,
                            ff_sparsity=0,
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            n_layers_forget=0,
                            forget_dense=True,
                            n_decoder_attention_layers=2,
                            use_bfloat16=False,
                            reversible_encoder=False,
                            use_two_swaps_per_encoder_block=True,
                            center_layernorm=True,
                            half_before_layer=None,
                            double_after_layer=None,
                            mode='train'):
  """Returns a highly configurable Terraformer encoder-decoder model.

  This model maps paired text sequences (source and target) to float-valued
  losses. If ``input_vocab_size`` is not ``None``, the layer takes
  two input sequences:

    - inputs (2):

        - source: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  If ``input_vocab_size`` is ``None``, the layer takes three input sequences:

    - inputs (3):

        - source: 3-D float array representing a batch of already-embedded text
          strings; shape is `(batch_size, sequence_length, d_model)`, where
          sequence_length <= ``max_len``.

        - mask: 2-D int array representing active versus masked positions; 0
          values mark masked (padding) positions.

        - target: 2-D int array representing a batch of text strings via token
          IDs plus padding markers; shape is `(batch_size, sequence_length)`,
          where sequence_length <= ``max_len``. Array elements are in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 1-D float array of losses; shape is `(batch_size)`.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in ``range(vocab_size)``. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if ``None``, then input and target integers (token IDs) are assumed to
        come from the same vocabulary.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    d_attention_key: Depth of key vectors in each attention head. If ``None``,
        defaults to ``d_model // n_heads``.
    d_attention_value: Depth of value vectors in each attention head. If
        ``None``, defaults to ``d_model // n_heads``.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder/decoder blocks. The same rate is
        also used for attention dropout in encoder/decoder blocks.
    max_len: Maximum symbol length for positional encoding.
    encoder_attention_type: Type of attention to use in the encoder; must be
        an attention-type subclass of :py:class:`trax.layers.Layer`.
    encoder_decoder_attention_type: Type of attention to use in the decoder;
        must be an attention-type subclass of :py:class:`trax.layers.Layer`.
        May also be a tuple/list of such types, in which case the types are
        cycled over the decoder layers and ``n_decoder_layers`` must be a
        multiple of its length.
    pos_type: String indicating the type of positional embeddings to use.
    pos_axial_shape: Shape (tuple of ints) to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: Tuple of ints specifying the depth of position embedding
        for each axis. Tuple length must match ``pos_axial_shape``, and values
        must sum to ``d_model``.
    pos_start_from_zero_prob: Stochastic rate (probability) for starting
        positional encoding at position 0 during training. If 1.0, always start
        from position 0; if < 1.0, the non-zero starts will be uniformly
        distributed up to ``pos_max_offset_to_add``.
    pos_max_offset_to_add: Maximum offset to add to positions during training
        when randomizing. This offset plus input length must be less than
        ``max_len`` for all training examples.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of :py:class:`trax.layers.Layer`.
    ff_use_sru: If > 0, use this number of SRU layers in place of feedforward
        layers.
    ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this
        size.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        at feedforward nonlinearities.
    ff_sparsity: If > 0, use sparse feedforward blocks with this level of
        sparsity.
    loss_sparsity_type: String indicating the type of sparsity to use in the
        loss layer; see :py:class:`SparseDenseWithOptions` for options. If
        ``None``, use no sparsity.
    loss_sparsity: If > 0, use this level of sparsity in the loss layer.
    loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this
        dimension, in the loss.
    loss_sparsity_prob: Stochastic rate (probability) for using the sparse
        version of the loss. If ``None``, use the sparse version exclusively.
    attention_chunk_size: If > 0, compute attention using chunks of this size.
    n_layers_forget: How often to have a forgetting block between layers.
    forget_dense: If True, use :py:class:`Dense` instances as forget layers;
        else use no-ops.
    n_decoder_attention_layers: Number of attention layers in a decoder block.
    use_bfloat16: If True, use bfloat16 for weights; else use float32.
    reversible_encoder: If True, make the encoder be reversible.
    use_two_swaps_per_encoder_block: If True, ensure that there is an even
        number of swaps across the encoder.
    center_layernorm: If True, use centering in :py:class:`LayerNorm` (the
        default); else omit centering (which is known as RMS normalization).
    half_before_layer: If not None, specifies an n'th layer such that all
        layers before the n'th use half the normal values for ``d_model`` and
        ``d_ff`` (and for the attention key/value depths); halves are computed
        with integer division.
    double_after_layer: If not None, specifies an n'th layer such that all
        layers after the n'th use double the normal values for ``d_model`` and
        ``d_ff`` (and for the attention key/value depths).
    mode: If ``'train'``, include dropout in each encoder/decoder block; else
        dropout layers have no effect. In ``'predict'`` mode the encoder is
        additionally wrapped in :py:class:`Cache` and a portal mask is
        threaded into the decoder layers.

  Returns:
    A Terraformer encoder-decoder as a layer that maps from target and source
    text sequences to a scalar loss.

  Raises:
    ValueError: If ``n_heads`` does not divide ``d_model / 2``.
    AssertionError: If ``encoder_decoder_attention_type`` is a tuple/list whose
        length does not divide ``n_decoder_layers``.
  """
  if mode == 'predict':
    portal_mask = _PortalInput()
  else:
    portal_mask = None

  # Set default dimensions for attention head key and value sizes.
  if (d_model / 2) % n_heads != 0:
    raise ValueError(
        f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
  if d_attention_key is None:
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    d_attention_value = d_model // n_heads

  # Set values of d_model, d_ff and d_qkv for the first stage.
  d_model1, d_ff1 = d_model, d_ff
  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
  if half_before_layer:
    # Floor division keeps the halved sizes integral: they are used as layer
    # dimensions, and true division would silently turn them into floats.
    d_model1, d_ff1 = d_model // 2, d_ff // 2
    d_attention_key1 = d_attention_key // 2
    d_attention_value1 = d_attention_value // 2

  # Set values of d_model, d_ff and d_qkv for the final stage.
  d_model2, d_ff2 = d_model, d_ff
  d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
  if double_after_layer:
    d_model2, d_ff2 = d_model * 2, d_ff * 2
    d_attention_key2 = d_attention_key * 2
    d_attention_value2 = d_attention_value * 2

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model1,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          pos_start_from_zero_prob=pos_start_from_zero_prob,
          pos_max_offset_to_add=pos_max_offset_to_add,
          use_bfloat16=use_bfloat16))

  def _EncoderBlock():
    return reformer.EncoderBlock(
        d_model1,
        d_ff1,
        n_heads,
        encoder_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        center_layernorm=center_layernorm,
        use_bfloat16=use_bfloat16,
        use_two_swaps_per_block=use_two_swaps_per_encoder_block,
        mode=mode)

  def _Encoder():  # vec_e mask_e tok_e tok_d tok_d
    layers = [
        tl.ReversibleSelect([0, 0]),
        _ReversibleSerialForget(
            [_EncoderBlock() for _ in range(n_encoder_layers)],
            d_model1,
            n_layers_forget,
            forget_dense)
    ]
    if not reversible_encoder:
      layers += [
          _XYAvg(),
          tl.Dense(d_model1, use_bfloat16=use_bfloat16),
          tl.LayerNorm(),
      ]
    if mode == 'predict':
      return tl.Cache(tl.Serial(layers))
    else:
      return tl.Serial(layers)

  # Normalize the decoder attention type(s) to a list that is cycled over the
  # decoder layers. Done before any monkey patching so that a failed check
  # leaves the patched classes untouched.
  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0, (
        'n_decoder_layers must be a multiple of the number of '
        'encoder_decoder_attention_type entries')
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]

  patch_for_predict = (mode == 'predict')
  if patch_for_predict:
    original_scan_sru_cell = tl.rnn.ScanSRUCell

  decoder_blocks = []
  try:
    if patch_for_predict:
      # TODO(jaszczur): Remove temporary fix of Terraformer padding in predict.
      # In predict mode Terraformer needs masking for merged encoder-decoder
      # sequence. This monkey patches the layer with a mask to necessary
      # places. This shouldn't be a permanent solution - mask should be passed
      # through the stack and all the layers.
      tl.attention.DotProductCausalAttention.monkey_patched_mask = (
          lambda x: portal_mask)
      tl.research.sparsity._RememberPad.monkey_patched_mask = (  # pylint: disable=protected-access
          lambda x: portal_mask)
      tl.rnn.ScanSRUCell = functools.partial(tl.rnn.ScanSRUCell,
                                             monkey_patched_mask=portal_mask)

    for layer_idx in range(n_decoder_layers):
      layer_attention_type = encoder_decoder_attention_type[
          layer_idx % len(encoder_decoder_attention_type)]
      # Grow d_model, d_ff, and d_qkv if requested.
      d_m, d_f, d_k, d_v = (
          d_model1, d_ff1, d_attention_key1, d_attention_value1)
      if half_before_layer and layer_idx >= half_before_layer:
        d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
      if double_after_layer and layer_idx > double_after_layer:
        d_m, d_f, d_k, d_v = (
            d_model2, d_ff2, d_attention_key2, d_attention_value2)
      decoder_block = reformer.DecoderBlock(
          d_m, d_f, d_k, d_v, n_heads,
          attention_type=layer_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          n_attention_layers=n_decoder_attention_layers,
          center_layernorm=center_layernorm,
          use_bfloat16=use_bfloat16,
          mode=mode)
      decoder_blocks.append(decoder_block)
      # A concatenation layer is inserted at each point where the width of
      # the following decoder blocks changes.
      if half_before_layer and layer_idx == half_before_layer - 1:
        decoder_blocks.append(tl.ReversibleConcatenatePair())
      if double_after_layer and layer_idx == double_after_layer:
        decoder_blocks.append(tl.ReversibleConcatenatePair())
  finally:
    if patch_for_predict:
      # After initializing the decoder (or failing to) revert to the original
      # state of the previously monkey-patched classes/functions, so that a
      # construction error cannot leak the patches into the rest of the
      # process.
      tl.attention.DotProductCausalAttention.monkey_patched_mask = (
          lambda x: None)
      tl.research.sparsity._RememberPad.monkey_patched_mask = (  # pylint: disable=protected-access
          lambda x: None)
      tl.rnn.ScanSRUCell = original_scan_sru_cell

  def _Loss():
    return tl.SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model2,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

  def _enc_dec_concat():
    """Layers to merge encoder and decoder."""
    if reversible_encoder:
      return [
          tl.ReversibleSelect([0, 1, 4, 2, 3]),  # v_e v_d mask_e tok_e tok_d
          t2.ConcatWithPadding2(mode=mode),      # v_ed v_ed tok_e tok_d
      ]
    else:
      return [
          tl.ReversibleSelect([0, 3, 1, 2]),     # v_e v_d mask_e tok_e tok_d
          t2.ConcatWithPadding(mode=mode),       # v_ed tok_e tok_d
          tl.ReversibleSelect([0, 0]),           # v_ed v_ed tok_e tok_d
      ]

  def _inp_layers():
    if input_vocab_size is not None:
      return tl.AssertFunction(
          'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth
          tl.Serial(  # tok_e tok_d
              tl.Select([0, 0, 0, 1]),
              tl.Parallel(in_encoder, [tl.PaddingMask(),
                                       _RemoveAxes12()])
          ))  # vec_e mask_e tok_e tok_d
    else:
      # Input in this case is vec_e, mask_e, tok_d. Where all downstream
      # operations expect tok_e, we give it instead mask_e, expecting that
      # downstream ops only are looking for padding/not padding.
      return tl.AssertFunction(
          'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth
          tl.Serial(  # vec_e mask_e tok_d
              tl.Select([0, 1, 1, 2]),
              tl.Parallel(in_encoder, [], _AsTokenIDs())
          ))  # vec_e mask_e tok_e tok_d

  # Assemble and return the model.
  return tl.Serial(
      _inp_layers(),               # vec_e mask_e tok_e tok_d
      tl.Parallel([], portal_mask),

      tl.Select([0, 1, 2, 3, 3]),  # Copy decoder tokens for use in loss.

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel([], [], [], [tl.ShiftRight(mode=mode),
                               out_encoder]),  # vec_e mask_e tok_e vec_d tok_d

      # Encode; then concat encoder and decoder, given encoder mask.
      _Encoder(),                             # vec_e mask_e tok_e vec_d tok_d
      _enc_dec_concat(),

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                              forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
      _XYAvg(),                               # vec_ed tok_e tok_d
      tl.LayerNorm(),                         # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector,
      # then compute loss.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d
      _Loss(),  # vec_d tok_d
  )
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              forget_dense=True,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head; if
      None, defaults to d_model // n_heads
    d_attention_value: int: depth of value vector for each attention head; if
      None, defaults to d_model // n_heads
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention; may also be a tuple/list of classes, which are cycled
      over the decoder layers (n_decoder_layers must then be a multiple of
      its length)
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from 0 during training,
          (if 1.0, we always start from position 0, if less, we randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
        when randomizing; this offset plus input length must still be less than
        max_len for all training examples.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to use in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    forget_dense: whether to use Dense or no-op (Serial) as a forget layer.
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    use_two_swaps_per_encoder_block: whether to allow even number of swaps in
      the encoder
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    half_before_layer: int, half d_model and d_ff (and the attention key/value
      depths) before that layer; halves use integer division
    double_after_layer: int, double d_model and d_ff (and the attention
      key/value depths) after that layer
    mode: str: 'train', 'eval', or 'predict'; in 'predict' mode the encoder is
      wrapped in tl.Cache and encoder inputs are passed through
      _ConvertToNaNsOnAnyZero (see the comment at its use below)

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.

  Raises:
    ValueError: if n_heads does not divide d_model / 2.
    AssertionError: if encoder_decoder_attention_type is a tuple/list whose
      length does not divide n_decoder_layers.
  """
  # Set default dimensions for attention head key and value sizes.
  if (d_model / 2) % n_heads != 0:
    raise ValueError(
        f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
  if d_attention_key is None:
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    d_attention_value = d_model // n_heads

  # Set values of d_model, d_ff and d_qkv for the first stage.
  d_model1, d_ff1 = d_model, d_ff
  d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
  if half_before_layer:
    # Floor division keeps the halved sizes integral: they are used as layer
    # dimensions, and true division would silently turn them into floats.
    d_model1, d_ff1 = d_model // 2, d_ff // 2
    d_attention_key1 = d_attention_key // 2
    d_attention_value1 = d_attention_value // 2

  # Set values of d_model, d_ff and d_qkv for the final stage.
  d_model2, d_ff2 = d_model, d_ff
  d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
  if double_after_layer:
    d_model2, d_ff2 = d_model * 2, d_ff * 2
    d_attention_key2 = d_attention_key * 2
    d_attention_value2 = d_attention_value * 2

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model1,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          pos_type=pos_type,
          pos_axial_shape=pos_axial_shape,
          pos_d_axial_embs=pos_d_axial_embs,
          pos_start_from_zero_prob=pos_start_from_zero_prob,
          pos_max_offset_to_add=pos_max_offset_to_add,
          use_bfloat16=use_bfloat16))

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model1,
          d_ff1,
          n_heads,
          encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          center_layernorm=center_layernorm,
          use_bfloat16=use_bfloat16,
          use_two_swaps_per_block=use_two_swaps_per_encoder_block,
          mode=mode)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = [                      # vec_e mask_e tok_e tok_d tok_d
      tl.ReversibleSelect([0, 0]),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget,
                              forget_dense)
  ]
  if not reversible_encoder:
    encoder += [
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model1, use_bfloat16=use_bfloat16),
        tl.LayerNorm(),
    ]
  encoder = tl.Serial(encoder)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # Normalize the decoder attention type(s) to a list that is cycled over the
  # decoder layers.
  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0, (
        'n_decoder_layers must be a multiple of the number of '
        'encoder_decoder_attention_type entries')
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]

  decoder_blocks = []
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    # Grow d_model, d_ff, and d_qkv if requested.
    d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
    if half_before_layer and layer_idx >= half_before_layer:
      d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
    if double_after_layer and layer_idx > double_after_layer:
      d_m, d_f, d_k, d_v = (
          d_model2, d_ff2, d_attention_key2, d_attention_value2)
    decoder_block = DecoderBlock(
        d_m, d_f, d_k, d_v, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        n_attention_layers=n_decoder_attention_layers,
        center_layernorm=center_layernorm,
        use_bfloat16=use_bfloat16,
        mode=mode)
    decoder_blocks.append(decoder_block)
    # A concatenation layer is inserted at each point where the width of the
    # following decoder blocks changes.
    if half_before_layer and layer_idx == half_before_layer - 1:
      decoder_blocks.append(tl.ReversibleConcatenatePair())
    if double_after_layer and layer_idx == double_after_layer:
      decoder_blocks.append(tl.ReversibleConcatenatePair())

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model2,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Layers to merge encoder and decoder, see below for details.
  if reversible_encoder:
    encdec_layers = [
        tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
    ]
  else:
    encdec_layers = [
        tl.ReversibleSelect([0, 3, 1, 2]),     # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),       # vec_ed tok_e tok_d
        tl.ReversibleSelect([0, 0]),           # vec_ed vec_ed tok_e tok_d
    ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),  # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],  # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      # Predict mode doesn't work with padding in encoder. Raising an exception
      # in jitted function isn't possible, so the second next best thing is
      # to convert every embedding to NaNs, so the user will not get subtly
      # wrong results, but clearly wrong results.
      (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),

      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                         # vec_e mask_e tok_e vec_d tok_d

      # Concat encoder and decoder, given encoder mask.
      encdec_layers,

      # Run decoder blocks.
      _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                              forget_dense),    # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,  # vec_d tok_d
  )