def _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation, pool_layer, pool_size, strides, separate_cls): """Internal funnel block. Returns a list of layers implementing it. The input is an activation tensor. Args: d_model: Final dimension of tensors at most points in the model, including the initial embedding output. d_ff: Size of special dense layer in the feed-forward part of each block. n_heads: Number of attention heads. dropout: Stochastic rate (probability) for dropping an activation value when applying dropout within a block. dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions. mode: If `'train'`, each block will include dropout; else, it will pass all values through unaltered. ff_activation: Type of activation function at the end of each block; must be an activation-type subclass of `Layer`. pool_layer: Type of pooling layer used for downsampling; should be `tl.AvgPool` or `tl.MaxPool`. pool_size: Shape of window that gets reduced to a single vector value. If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` must be a tuple of length :math:`n-2`. strides: Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as `pool_size`. If None, then offsets of 1 along each window axis, :math:`(1, ..., 1)`, will be used. separate_cls: If `True`, pooling in funnel blocks is not applied to embeddings of the first token (`cls` from BERT paper). Returns: A list of layers that maps (activations, mask) to (activations', mask). """ pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls) mask_pooling = MaskPool(pool_size, strides, separate_cls) attention = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode) hidden_dropout = tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) return [ # h, mask tl.LayerNorm(), # h, mask tl.Branch(pooling, None), # h', h, mask tl.Residual( tl.Select([0, 1, 1, 2]), # h', h, h, mask attention, # attn, mask tl.Parallel(None, mask_pooling), # attn, mask' hidden_dropout # attn, mask' ), # funnel_activations, mask' tl.Residual(feed_forward) ]
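# Hedged usage sketch (illustration only, not part of the library code above):
# wiring one funnel block into a Serial layer. It assumes the surrounding
# module already provides PoolLayer, MaskPool and _FeedForwardBlock, exactly
# as _FunnelBlock itself does.
import trax.layers as tl

funnel_block = tl.Serial(_FunnelBlock(
    d_model=512, d_ff=2048, n_heads=8, dropout=0.1, dropout_shared_axes=None,
    mode='train', ff_activation=tl.Relu, pool_layer=tl.AvgPool,
    pool_size=(2,), strides=(2,), separate_cls=True))
# The resulting layer maps (activations, mask) with sequence length n to
# (activations', mask') with sequence length roughly n / 2.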
def EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode): """Returns a layer sequence that implements a Transformer encoder block. The input to the layer sequence is a pair, (activations, mask), where the mask was created from the original source tokens to prevent attending to the padding part of the input. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) layer_idx: which layer are we at (for bookkeeping) mode: str: 'train' or 'eval' Returns: A sequence of layers that maps an (activations, mask) pair to an (activations, mask) pair. """ attention = [ tl.LayerNorm(), tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode), tl.Dropout(rate=dropout, name='enc_attn_dropout', mode=mode), ] feed_forward = [ FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode), ] return tl.Serial( tl.Residual(attention), tl.Residual(feed_forward), )
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
  """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model (int): depth of embedding.
    d_ff (int): depth of feed-forward layer.
    n_heads (int): number of attention heads.
    dropout (float): dropout rate (how much to drop out).
    mode (str): 'train' or 'eval'.
    ff_activation (function): the non-linearity in feed-forward layer.

  Returns:
    list: list of trax.layers.combinators.Serial that maps an activation
    tensor to an activation tensor.
  """
  # Two residual sub-blocks: causal attention (with normalization and dropout)
  # followed by the feed-forward block.
  return [
      tl.Residual(
          # Normalize the layer input.
          tl.LayerNorm(),
          # Add causal attention.
          tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout,
                             mode=mode)),
      tl.Residual(
          # Add the feed-forward block. It normalizes its own input, so no
          # extra LayerNorm is needed here.
          FeedForward(d_model, d_ff, dropout, mode, ff_activation)),
  ]
def Encoder(d_model, d_ff, n_heads, dropout, mode, ff_activation):
  """Returns a list of layers implementing a Transformer decoder-style block
  (causal self-attention followed by a feed-forward block).

  The input is an activation tensor.

  Args:
    d_model (int): depth of embedding.
    d_ff (int): depth of feed-forward layer.
    n_heads (int): number of attention heads.
    dropout (float): dropout rate (how much to drop out).
    mode (str): 'train' or 'eval'.
    ff_activation (function): the non-linearity in feed-forward layer.

  Returns:
    list: list of trax.layers.combinators.Serial that maps an activation
    tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(d_model, n_heads=n_heads,
                                        dropout=dropout, mode=mode)

  feed_forward = [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      ff_activation(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, mode=mode)
  ]

  return [
      tl.Residual(tl.LayerNorm(), causal_attention,
                  tl.Dropout(rate=dropout, mode=mode)),
      tl.Residual(feed_forward),
  ]
def _DecoderBlock(positions, d_model, d_ff, n_heads, dropout, mode): """Returns a layer sequence representing a Transformer decoder. (acts, pos) --> (acts', pos') Args: positions: random vectors for positions d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) mode: str: 'train' or 'eval' """ return tl.Serial( tl.Residual( # Self-attention block. tl.LayerNorm(), AttentionPosition(positions=positions, d_model=d_model, n_heads=n_heads, dropout=dropout, mode=mode), tl.Dropout(rate=dropout, mode=mode)), tl.Residual( tl.LayerNorm(), tl.Dense(d_ff), tl.Relu(), tl.Dropout(rate=dropout, mode=mode), tl.Dense(d_model), tl.Dropout(rate=dropout, mode=mode), ))
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation):
  """Returns a list of layers implementing a Transformer encoder-decoder block.

  The input is a triple (decoder_activations, mask, encoder_activations) where
  the mask is created from the original input token IDs to prevent attending
  to the padding part of the encoder.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
        way to save memory and apply consistent masks to activation vectors at
        different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will pass all
        values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.

  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
  def _Dropout():
    return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  attention_qkv = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode,
      cache_KV_in_predict=True)

  causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  return [                             # vec_d masks vec_e
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          causal_attention,            # vec_d ..... .....
          _Dropout(),                  # vec_d ..... .....
      ),
      tl.Residual(
          tl.LayerNorm(),              # vec_d ..... .....
          tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
          attention_qkv,               # vec_d masks vec_e
          _Dropout(),                  # vec_d masks vec_e
      ),
      tl.Residual(
          feed_forward                 # vec_d masks vec_e
      ),
  ]
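# Hedged sketch (illustration only): stacking several of these encoder-decoder
# blocks, roughly the way the decoder side of a Transformer is built. Assumes
# _EncoderDecoderBlock and _FeedForwardBlock from the module above are in
# scope.
import trax.layers as tl

decoder_side = tl.Serial([
    _EncoderDecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                         dropout_shared_axes=None, mode='train',
                         ff_activation=tl.Relu)
    for _ in range(6)
])
# Maps (decoder_activations, mask, encoder_activations) to a triple of the
# same sort.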
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation):
  """Returns a list of layers that implements a Transformer decoder block.

  The input to the block is a pair (activations, mask) where the mask encodes
  causal connections, preventing attention to future positions in the
  sequence. The block's outputs are the same type/shape as its inputs, so that
  multiple blocks can be chained together.

  Args:
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within decoder blocks. The same rate is also
        used for attention dropout in decoder blocks.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
        along batch and sequence axes (``dropout_shared_axes=(0,1)``) is a
        useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If ``'train'``, each block will include dropout; else, it will pass
        all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of :py:class:`Layer`.

  Returns:
    A list of layers that act in series as a (repeatable) decoder block.
  """
  def _CausalAttention():
    return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout,
                              mode=mode)

  def _FFBlock():
    return _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode,
                             ff_activation)

  def _Dropout():
    return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  return [
      tl.Residual(
          tl.LayerNorm(),
          _CausalAttention(),
          _Dropout(),
      ),
      tl.Residual(
          tl.LayerNorm(),
          _FFBlock(),
          _Dropout(),
      ),
  ]
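# Hedged sketch (illustration only): stacking several of these decoder blocks
# into a language-model trunk, with the embedding and output layers omitted.
# Assumes _DecoderBlock and _FeedForwardBlock from the module above are in
# scope.
import trax.layers as tl

n_layers = 6
lm_trunk = tl.Serial(
    [_DecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                   dropout_shared_axes=None, mode='train',
                   ff_activation=tl.Relu)
     for _ in range(n_layers)],
    tl.LayerNorm(),
)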
def EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                 ff_activation, FeedForwardBlock=FeedForwardBlock):
  """Returns a list of layers that implements a Transformer encoder block.

  The input to the block is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model (int): depth of embedding.
    d_ff (int): depth of feed-forward layer.
    n_heads (int): number of attention heads.
    dropout (float): dropout rate (how much to drop out).
    dropout_shared_axes (int): axes on which to share dropout mask.
    mode (str): 'train' or 'eval'.
    ff_activation (function): the non-linearity in feed-forward layer.
    FeedForwardBlock (function): A function that returns the feed-forward
        block.

  Returns:
    list: A list of layers that maps (activations, mask) to
    (activations, mask).
  """
  # Multi-head self-attention block.
  attention = tl.Attention(
      d_feature=d_model,
      n_heads=n_heads,
      dropout=dropout,
      mode=mode)

  # Feed-forward block built by the supplied constructor.
  feed_forward = FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  # Dropout applied after attention.
  dropout_ = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  encoder_block = [
      # Residual connection around normalization, attention and dropout.
      tl.Residual(
          tl.LayerNorm(),
          attention,
          dropout_,
      ),
      # Residual connection around the feed-forward block.
      tl.Residual(feed_forward),
  ]
  return encoder_block
def _FunnelRelativeDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation, total_pooling, shorten_factor, resampler_fn): """Returns a list of layers that implements a Transformer decoder block. The input is an activation tensor. Args: d_model: Final dimension of tensors at most points in the model, including the initial embedding output. d_ff: Size of special dense layer in the feed-forward part of each block. n_heads: Number of attention heads. dropout: Stochastic rate (probability) for dropping an activation value when applying dropout within a block. dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions. mode: If `'train'`, each block will include dropout; else, it will pass all values through unaltered. ff_activation: Type of activation function at the end of each block; must be an activation-type subclass of `Layer`. total_pooling: total pooling. shorten_factor: by how much shorten/upsample at this funnel block. resampler_fn: Type of function that performs funnel upsampling/downsampling; callable with signature: shorten_factor, d_model; must return an activation-type subclass of `Layer`. Returns: A list of layers that maps an activation tensor to an activation tensor. """ resampler = resampler_fn(shorten_factor, d_model) attention = RelativeAttentionLMLayer( d_model, total_pooling, n_heads=n_heads, dropout=dropout, mode=mode) feed_forward = _FeedForwardBlock( d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) dropout_ = tl.Dropout( rate=dropout, shared_axes=dropout_shared_axes, mode=mode) return [ tl.LayerNorm(), # h tl.Branch(tl.Serial( resampler, tl.LayerNorm(), ), None), # h', h tl.Residual( tl.Select([0, 1, 1]), # h', h, h attention, dropout_, ), tl.Residual( feed_forward ), ]
def _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                          mode, ff_activation, context_bias_layer,
                          location_bias_layer, total_pooling):
  """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor; the block applies causal relative
  attention (Transformer-XL style) followed by a feed-forward block.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
        way to save memory and apply consistent masks to activation vectors at
        different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will pass all
        values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    context_bias_layer: Global context bias from Transformer XL's attention.
    location_bias_layer: Global location bias from Transformer XL's attention.
    total_pooling: The combined pool size of previously used funnel blocks.

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  attention = RelativeAttentionLMLayer(
      d_model, context_bias_layer, location_bias_layer, total_pooling,
      n_heads=n_heads, dropout=dropout, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  dropout_ = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  return [
      tl.Residual(               # vecs
          tl.LayerNorm(),
          tl.Select([0, 0, 0]),  # duplicate activations for (q, k, v)
          attention,
          dropout_,
      ),                         # vecs
      tl.Residual(
          feed_forward
      ),                         # vecs
  ]
def WideResnetGroup(n, channels, strides=(1, 1), bn_momentum=0.9, mode='train'):
  """WideResnet group: one (possibly striding) residual block with a 3x3
  convolutional shortcut, followed by n - 1 shape-preserving blocks wrapped
  in a single identity-shortcut residual."""
  shortcut = [
      tl.Conv(channels, (3, 3), strides, padding='SAME'),
  ]
  return [
      tl.Residual(WideResnetBlock(channels, strides, bn_momentum=bn_momentum,
                                  mode=mode),
                  shortcut=shortcut),
      tl.Residual([WideResnetBlock(channels, (1, 1), bn_momentum=bn_momentum,
                                   mode=mode)
                   for _ in range(n - 1)]),
  ]
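# Hedged sketch (illustration only): composing groups the way a WideResnet
# body typically does, widening and striding after the first group. Assumes
# WideResnetBlock from the module above is in scope.
import trax.layers as tl

widths = [16, 32, 64]
body = []
for i, channels in enumerate(widths):
  body.extend(WideResnetGroup(n=4, channels=channels,
                              strides=(1, 1) if i == 0 else (2, 2),
                              mode='train'))
resnet_body = tl.Serial(body)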
def DecoderBlock(embeddingDepth, depth, n_heads, dropout, mode, ffActivation):
  """Transformer decoder block: a residual causal-attention sub-block
  followed by a residual feed-forward sub-block."""
  return [
      tl.Residual(
          tl.LayerNorm(),
          tl.CausalAttention(embeddingDepth, n_heads=n_heads, dropout=dropout,
                             mode=mode)),
      tl.Residual(
          FeedForward(embeddingDepth, depth, dropout, mode, ffActivation)),
  ]
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode, ff_activation): """Returns a list of layers implementing a Transformer encoder-decoder block. The input is a triple (decoder_input, mask, encoder) where the mask is created from the original source to prevent attending to the padding part of the encoder. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) layer_idx: which layer are we at (for bookkeeping) mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: A list of layers which maps triples (decoder_activations, mask, encoder_activations) to triples of the same sort. """ def _Dropout(): return tl.Dropout(rate=dropout, mode=mode) attention_qkv = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode) basic_causal_attention = tl.BasicCausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode, ff_activation) return [ # vec_d masks vec_e tl.Residual( tl.LayerNorm(), # vec_d ..... ..... basic_causal_attention, # vec_d masks ..... _Dropout(), # vec_d ..... ..... ), tl.Residual( tl.LayerNorm(), # vec_d ..... ..... tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e attention_qkv, # vec_d masks vec_e _Dropout(), # vec_d masks vec_e ), tl.Residual(feed_forward # vec_d masks vec_e ), ]
def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode, ff_activation): """Transformer encoder-decoder layer. The input is a triple (decoder_input, mask, encoder) where the mask is created from the original source to prevent attending to the padding part of the encoder. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) layer_idx: which layer are we at (for bookkeeping) mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: the layer, returning a triple (decoder_activations, mask, encoder). """ decoder_self_attention = [ # vecs_d pmask vecs_e tl.LayerNorm(), # vecs_d ..... ...... tl.BasicCausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode), tl.Dropout(rate=dropout, mode=mode), # vecs_d ..... ...... ] decoder_to_encoder_attention = [ # vecs_d masks vecs_e tl.LayerNorm(), # vecs_d masks vecs_e tl.Parallel([], [], tl.Dup()), # ______ _____ vecs_e vecs_e tl.Parallel([], tl.Swap()), # ______ vecs_e masks ...... tl.Parallel([], tl.Dup()), # ______ vecs_e vecs_e ..... ...... tl.AttentionQKV( # (q k v masks ... --> vecs_d masks ...) d_model, n_heads=n_heads, dropout=dropout, mode=mode), tl.Dropout(rate=dropout, mode=mode), # vecs_d mask vecs_e ] feed_forward = [ FeedForward(d_model, d_ff, dropout, layer_idx, mode, ff_activation), ] return tl.Serial( # vecs_d masks vecs_e tl.Residual(decoder_self_attention), # vecs_d masks vecs_e tl.Residual(decoder_to_encoder_attention), # vecs_d masks vecs_e tl.Residual(feed_forward), # vecs_d masks vecs_e )
def ConvBlock(kernel_size, filters, strides, norm, non_linearity, mode='train'): """ResNet convolutional striding block.""" # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant. ks = kernel_size filters1, filters2, filters3 = filters main = [ tl.Conv(filters1, (1, 1), strides), norm(mode=mode), non_linearity(), tl.Conv(filters2, (ks, ks), padding='SAME'), norm(mode=mode), non_linearity(), tl.Conv(filters3, (1, 1)), norm(mode=mode), ] shortcut = [ tl.Conv(filters3, (1, 1), strides), norm(mode=mode), ] return [tl.Residual(main, shortcut=shortcut), non_linearity()]
def NMTAttn(input_vocab_size=33300, target_vocab_size=33300, d_model=1024,
            n_encoder_layers=2, n_decoder_layers=2, n_attention_heads=4,
            attention_dropout=0.0, mode='train'):
  """Returns an LSTM sequence-to-sequence model with attention."""
  input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

  pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size,
                                                   d_model)

  model = tl.Serial(
      # Copy input tokens and target tokens for later use.
      tl.Select([0, 1, 0, 1]),
      # Run the input encoder and the pre-attention decoder in parallel.
      tl.Parallel(input_encoder, pre_attention_decoder),
      # Prepare queries, keys, values and mask for attention.
      tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
      # Nest attention inside a Residual layer so its output is added to the
      # pre-attention decoder activations (i.e. the queries).
      tl.Residual(
          tl.AttentionQKV(d_model, n_heads=n_attention_heads,
                          dropout=attention_dropout, mode=mode)),
      # Drop the attention mask (keep the activations and target tokens).
      tl.Select([0, 2]),
      # Run the rest of the LSTM decoder.
      [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
      # Map to the target vocabulary and take log-softmax.
      tl.Dense(target_vocab_size),
      tl.LogSoftmax())

  return model
def DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode): """Transformer decoder layer. Args: positions: random vectors for positions d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) mode: str: 'train' or 'eval' Returns: the layer. """ return [ tl.Residual( # Self-attention block. PreservePosition(tl.LayerNorm()), tl.Dup(), tl.Parallel( [], # activation for (q, k, v) tl.CausalMask(axis=-2)), # attention mask AttentionPosition(positions, d_model, n_heads=n_heads, dropout=dropout, mode=mode), PreservePosition(tl.Dropout(rate=dropout, mode=mode))), ResidualFeedForward(d_model, d_ff, dropout, mode=mode) ]
def NMTAttn(input_vocab_size=33300, target_vocab_size=33300, d_model=1024,
            n_encoder_layers=2, n_decoder_layers=2, n_attention_heads=4,
            attention_dropout=0.0, mode='train'):
  """Returns an LSTM sequence-to-sequence model with attention.

  The input to the model is a pair (input tokens, target tokens), e.g., an
  English sentence (tokenized) and its translation into German (tokenized).

  Args:
    input_vocab_size: int: vocab size of the input
    target_vocab_size: int: vocab size of the target
    d_model: int: depth of embedding (n_units in the LSTM cell)
    n_encoder_layers: int: number of LSTM layers in the encoder
    n_decoder_layers: int: number of LSTM layers in the decoder after attention
    n_attention_heads: int: number of attention heads
    attention_dropout: float, dropout for the attention layer
    mode: str: 'train', 'eval' or 'predict'; predict mode is for fast inference

  Returns:
    An LSTM sequence-to-sequence model with attention.
  """
  # Create the input encoder that produces the encoder activations.
  input_encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

  # Create the layers for the pre-attention decoder.
  pre_attention_decoder = pre_attention_decoder_fn(mode, target_vocab_size,
                                                   d_model)

  # Model
  model = tl.Serial(
      # Copy input tokens and target tokens for later use.
      tl.Select([0, 1, 0, 1]),
      # Run the input encoder on the input and the pre-attention decoder on
      # the target, in parallel.
      tl.Parallel(input_encoder, pre_attention_decoder),
      # Prepare queries, keys, values and mask for attention.
      tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
      # Nest the AttentionQKV layer inside a Residual layer so its output is
      # added to the pre-attention decoder activations.
      tl.Residual(
          tl.AttentionQKV(d_model, n_heads=n_attention_heads,
                          dropout=attention_dropout, mode=mode)),
      # Drop the attention mask.
      tl.Select([0, 2]),
      # Run the rest of the RNN decoder.
      [tl.LSTM(n_units=d_model) for _ in range(n_decoder_layers)],
      # Dense layer of target size.
      tl.Dense(target_vocab_size),
      # Log-softmax for output.
      tl.LogSoftmax()
  )

  return model
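# Hedged usage sketch (illustration only): instantiating the model and
# inspecting its layer structure. Actually running it additionally requires
# the input_encoder_fn, pre_attention_decoder_fn and prepare_attention_input
# helpers that NMTAttn assumes are defined in this module.
model = NMTAttn(mode='train')
print(model)  # prints the Serial combinator with its sub-layers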
def DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode): """Transformer decoder layer. (acts, pos) --> (acts', pos') Args: positions: random vectors for positions d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) mode: str: 'train' or 'eval' Returns: the layer. """ return tl.Serial( tl.Residual( # Self-attention block. tl.LayerNorm(), AttentionPosition(positions=positions, d_model=d_model, n_heads=n_heads, dropout=dropout, mode=mode), tl.Dropout(rate=dropout, mode=mode) ), ResidualFeedForward(d_model, d_ff, dropout, mode=mode) )
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation):
  """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
        way to save memory and apply consistent masks to activation vectors at
        different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will pass all
        values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  dropout_ = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          dropout_,
      ),
      tl.Residual(feed_forward),
  ]
def _DecoderBlock(d_model, d_ff, n_heads, d_attn_key, d_attn_value, attn_type,
                  dropout, share_qk, layer_idx, mode, ff_activation):
  """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    d_attn_key: int: depth of key vector for each attention head
    d_attn_value: int: depth of value vector for each attention head
    attn_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, d_attention_key=d_attn_key,
      d_attention_value=d_attn_value, attention_type=attn_type,
      share_qk=share_qk, mode=mode)

  dropout_ = tl.Dropout(
      rate=dropout, name='attention_%d' % layer_idx, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, layer_idx, mode, ff_activation)

  return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          dropout_,
      ),
      tl.Residual(feed_forward),
  ]
def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
  """Returns a list of layers that implements a Transformer decoder block."""
  # Causal (masked) self-attention, so each position attends only to the past.
  causal_attention = CausalAttention(d_model, n_heads=n_heads, mode=mode)

  # Shallow feed-forward network with layer normalization and dropout to
  # reduce overfitting.
  feed_forward = [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      ff_activation(),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, mode=mode)
  ]

  # Residual connections let each sub-block learn a correction on top of its
  # input, which keeps deep stacks of these blocks trainable.
  return [
      tl.Residual(tl.LayerNorm(), causal_attention,
                  tl.Dropout(rate=dropout, mode=mode)),
      tl.Residual(feed_forward),
  ]
def _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation): """Returns a list of layers that implements a Transformer encoder block. The input to the layer is a pair, (activations, mask), where the mask was created from the original source tokens to prevent attending to the padding part of the input. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) dropout_shared_axes: axes on which to share dropout mask mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: A list of layers that maps (activations, mask) to (activations, mask). """ attention = tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode) feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) dropout_ = tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) return [ tl.Residual( tl.LayerNorm(), attention, dropout_, ), tl.Residual(feed_forward), ]
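# Hedged sketch (illustration only): an encoder built from these blocks.
# Assumes _EncoderBlock and _FeedForwardBlock from the module above are in
# scope. Each block consumes and produces an (activations, mask) pair, so the
# blocks can be chained directly.
import trax.layers as tl

encoder = tl.Serial([
    _EncoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                  dropout_shared_axes=None, mode='train',
                  ff_activation=tl.Relu)
    for _ in range(6)
])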
def DecoderBlock(d_model, d_ff, n_heads, d_attention_key, d_attention_value, attention_type, dropout, share_qk, layer_idx, mode): """Returns a layer sequence that implements a Transformer decoder block. The input to the layer sequence is an activation tensor. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head attention_type: subclass of tl.BaseCausalAttention: attention class to use dropout: float: dropout rate (how much to drop out) share_qk: bool, whether to share queries and keys layer_idx: which layer are we at (for bookkeeping) mode: str: 'train' or 'eval' Returns: A sequence of layers that maps an activation tensor to an activation tensor. """ self_attention = [ tl.LayerNorm(), # vec tl.CausalAttention(d_model, n_heads=n_heads, d_attention_key=d_attention_key, d_attention_value=d_attention_value, attention_type=attention_type, share_qk=share_qk, mode=mode), tl.Dropout(rate=dropout, name='attention_%d' % layer_idx, mode=mode), ] feed_forward = [ FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode), ] return tl.Serial( tl.Residual(self_attention), tl.Residual(feed_forward), )
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation):
  """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """
  causal_attention = tl.CausalAttention(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  dropout_ = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  return [
      tl.Residual(
          tl.LayerNorm(),
          causal_attention,
          dropout_,
      ),
      tl.Residual(feed_forward),
  ]
def ResidualFeedForward(d_model, d_ff, dropout, mode): """Residual feed-forward layer with normalization at start.""" stack = tl.Serial( tl.LayerNorm(), tl.Dense(d_ff), tl.Relu(), tl.Dropout(rate=dropout, mode=mode), tl.Dense(d_model), tl.Dropout(rate=dropout, mode=mode), ) return tl.Residual(stack)
def ResidualSwitchUnit(
    d_model, dropout=0.1, mode='train', residual_weight=0.9):
  r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf.

  As defined in the paper:

  .. math::
    i &= [i_1, i_2] \\
    g &= GELU(LayerNorm(Z i)) \\
    c &= W g + B \\
    [o_1, o_2] &= \sigma(S) \bigodot i + h \bigodot c

  where Z, W, B, S are learnable parameters with sizes 2m × 4m, 4m × 2m, 2m,
  2m. We assume that both i_1 and i_2 have size m. h is a scalar value.

  We assume the input is of shape [batch, length, depth].

  Args:
    d_model: output depth of the RSU layer
    dropout: dropout rate used in 'train' mode
    mode: mode for dropout layer
    residual_weight: value used in initializing vector S and constant h

  Returns:
    The RSU layer.
  """
  return tl.Serial(
      tl.Fn(
          'Reshape2Pairs',
          lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)),
          n_out=1),
      tl.Residual(
          tl.Dense(4 * d_model, use_bias=False),
          tl.LayerNorm(),
          tl.Gelu(),
          tl.Dense(2 * d_model),
          tl.Fn('Scaling',
                lambda x: x * np.sqrt(1 - residual_weight**2) * 0.25,
                n_out=1),
          shortcut=_ClippedScaling(residual_weight)),
      tl.Fn(
          'UnPair',
          lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)),
          n_out=1),
      tl.Dropout(rate=dropout, mode=mode)
  )
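# Hedged usage sketch (illustration only): the RSU pairs neighbouring
# positions, so it expects inputs of shape [batch, length, depth] with an even
# length, and returns an output of the same shape. Assumes jnp/np and
# _ClippedScaling from the module above are in scope.
import numpy as np
from trax import shapes

rsu = ResidualSwitchUnit(d_model=64, dropout=0.1, mode='eval')
x = np.zeros((8, 16, 64), dtype=np.float32)  # batch=8, length=16, depth=64
rsu.init(shapes.signature(x))
y = rsu(x)  # same shape as x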
def IdentityBlock(kernel_size, filters, norm, non_linearity, mode='train'): """ResNet identical size block.""" ks = kernel_size filters1, filters2, filters3 = filters main = [ tl.Conv(filters1, (1, 1)), norm(mode=mode), non_linearity(), tl.Conv(filters2, (ks, ks), padding='SAME'), norm(mode=mode), non_linearity(), tl.Conv(filters3, (1, 1)), norm(mode=mode), ] return [ tl.Residual(main), non_linearity(), ]
def IdentityBlock(kernel_size, filters, mode='train'): """ResNet identical size block.""" # TODO(jonni): Use good defaults so Resnet50 code is cleaner / less redundant. ks = kernel_size filters1, filters2, filters3 = filters main = [ tl.Conv(filters1, (1, 1)), tl.BatchNorm(mode=mode), tl.Relu(), tl.Conv(filters2, (ks, ks), padding='SAME'), tl.BatchNorm(mode=mode), tl.Relu(), tl.Conv(filters3, (1, 1)), tl.BatchNorm(mode=mode), ] return [ tl.Residual(main), tl.Relu(), ]
def ConvBlock(kernel_size, filters, strides, norm, non_linearity, mode='train'): """ResNet convolutional striding block.""" ks = kernel_size filters1, filters2, filters3 = filters main = [ tl.Conv(filters1, (1, 1), strides), norm(mode=mode), non_linearity(), tl.Conv(filters2, (ks, ks), padding='SAME'), norm(mode=mode), non_linearity(), tl.Conv(filters3, (1, 1)), norm(mode=mode), ] shortcut = [ tl.Conv(filters3, (1, 1), strides), norm(mode=mode), ] return [tl.Residual(main, shortcut=shortcut), non_linearity()]
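# Hedged sketch (illustration only): one ResNet-50-style stage, a striding
# ConvBlock followed by shape-preserving IdentityBlocks (as defined earlier in
# this file), using BatchNorm and Relu for the norm / non-linearity.
import trax.layers as tl

stage = [
    ConvBlock(3, [64, 64, 256], (1, 1), tl.BatchNorm, tl.Relu, mode='train'),
    IdentityBlock(3, [64, 64, 256], tl.BatchNorm, tl.Relu, mode='train'),
    IdentityBlock(3, [64, 64, 256], tl.BatchNorm, tl.Relu, mode='train'),
]
resnet_stage = tl.Serial(stage)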