def DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int): depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation
        tensor to an activation tensor.
    """
    # Two residual blocks: causal attention (with normalization and dropout)
    # followed by the feed-forward block.
    return [
        tl.Residual(
            # Normalize layer input.
            tl.LayerNorm(),
            # Add causal attention; the feature depth is d_model.
            tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)
        ),
        tl.Residual(
            # Add feed-forward block. We don't need to normalize the layer
            # inputs here; the feed-forward block takes care of that for us.
            FeedForward(d_model, d_ff, dropout, mode, ff_activation)
        ),
    ]
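# A minimal usage sketch (not from the original source): since DecoderBlock
# returns a list of layers, it can be spliced directly into a decoder-only
# language model with tl.Serial. The function name `TransformerLM_sketch` and
# the hyperparameters (vocab_size, max_len, n_layers) are illustrative
# assumptions; FeedForward is the helper used by DecoderBlock above.
from trax import layers as tl

def TransformerLM_sketch(vocab_size=33300, d_model=512, d_ff=2048, n_layers=6,
                         n_heads=8, dropout=0.1, max_len=4096, mode='train',
                         ff_activation=tl.Relu):
    return tl.Serial(
        tl.ShiftRight(mode=mode),                         # teacher forcing: shift targets right
        tl.Embedding(vocab_size=vocab_size, d_feature=d_model),  # token IDs -> embeddings
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
        [DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation)
         for _ in range(n_layers)],                       # stack of decoder blocks
        tl.LayerNorm(),
        tl.Dense(vocab_size),                             # project to vocabulary logits
        tl.LogSoftmax(),                                  # log-probabilities over tokens
    )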
def Encoder(d_model, d_ff, n_heads, dropout, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model (int): depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        mode (str): 'train' or 'eval'.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        list: list of trax.layers.combinators.Serial that maps an activation
        tensor to an activation tensor.
    """
    causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)

    feed_forward = [
        tl.LayerNorm(),
        tl.Dense(d_ff),
        ff_activation(),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Dense(d_model),
        tl.Dropout(rate=dropout, mode=mode)
    ]

    return [
        tl.Residual(tl.LayerNorm(), causal_attention, tl.Dropout(rate=dropout, mode=mode)),
        tl.Residual(feed_forward),
    ]
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation): """Returns a list of layers implementing a Transformer encoder-decoder block. The input is a triple (decoder_activations, mask, encoder_activiations) where the mask is created from the original input token IDs to prevent attending to the padding part of the encoder. Args: d_model: Final dimension of tensors at most points in the model, including the initial embedding output. d_ff: Size of special dense layer in the feed-forward part of each block. n_heads: Number of attention heads. dropout: Stochastic rate (probability) for dropping an activation value when applying dropout within a block. dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions. mode: If `'train'`, each block will include dropout; else, it will pass all values through unaltered. ff_activation: Type of activation function at the end of each block; must be an activation-type subclass of `Layer`. Returns: A list of layers which maps triples (decoder_activations, mask, encoder_activations) to triples of the same sort. """ def _Dropout(): return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) attention_qkv = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode, cache_KV_in_predict=True) causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode) feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) return [ # vec_d masks vec_e tl.Residual( tl.LayerNorm(), # vec_d ..... ..... causal_attention, # vec_d ..... ..... _Dropout(), # vec_d ..... ..... ), tl.Residual( tl.LayerNorm(), # vec_d ..... ..... tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e attention_qkv, # vec_d masks vec_e _Dropout(), # vec_d masks vec_e ), tl.Residual(feed_forward # vec_d masks vec_e ), ]
def test_simple_call(self):
    layer = tl.CausalAttention(d_feature=4, n_heads=2)
    x = np.array([[[2, 5, 3, 4],
                   [0, 1, 2, 3],
                   [0, 1, 2, 3]]])
    _, _ = layer.init(shapes.signature(x))

    y = layer(x)
    self.assertEqual(y.shape, (1, 3, 4))
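# A hedged companion check (not part of the original test suite; the test name
# and values are illustrative, and it is meant to sit in the same TestCase):
# with dropout configured but mode='eval', the layer should still preserve the
# (batch, seq_len, d_feature) shape.
def test_eval_mode_shape(self):
    layer = tl.CausalAttention(d_feature=6, n_heads=3, dropout=0.1, mode='eval')
    x = np.ones((2, 5, 6), dtype=np.float32)  # (batch, seq_len, d_feature)
    _, _ = layer.init(shapes.signature(x))

    y = layer(x)
    self.assertEqual(y.shape, (2, 5, 6))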
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation): """Returns a list of layers implementing a Transformer encoder-decoder block. The input is a triple (decoder_input, mask, encoder) where the mask is created from the original source to prevent attending to the padding part of the encoder. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) dropout_shared_axes: axes on which to share dropout mask mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: A list of layers which maps triples (decoder_activations, mask, encoder_activations) to triples of the same sort. """ def _Dropout(): return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) attention_qkv = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode, cache_KV_in_predict=True) causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode) feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) return [ # vec_d masks vec_e ResidualZero( tl.LayerNorm(), # vec_d ..... ..... causal_attention, # vec_d ..... ..... _Dropout(), # vec_d ..... ..... ), ResidualZero( tl.LayerNorm(), # vec_d ..... ..... tl.Select([0, 2, 2, 1, 2]), # vec_d vec_e vec_e masks vec_e attention_qkv, # vec_d masks vec_e _Dropout(), # vec_d masks vec_e ), ResidualZero( tl.LayerNorm(), feed_forward, # vec_d masks vec_e _Dropout(), ), ]
def DecoderBlock(embeddingDepth, depth, n_heads, dropout, mode, ffActivation):
    # Transformer decoder block: causal attention followed by a feed-forward
    # block, each wrapped in a residual connection.
    return [
        tl.Residual(
            tl.LayerNorm(),
            tl.CausalAttention(embeddingDepth, n_heads=n_heads, dropout=dropout, mode=mode)),
        tl.Residual(
            FeedForward(embeddingDepth, depth, dropout, mode, ffActivation)),
    ]
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation): """Returns a list of layers that implements a Transformer decoder block. The input is an activation tensor. Args: d_model: Final dimension of tensors at most points in the model, including the initial embedding output. d_ff: Size of special dense layer in the feed-forward part of each block. n_heads: Number of attention heads. dropout: Stochastic rate (probability) for dropping an activation value when applying dropout within a block. dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions. mode: If `'train'`, each block will include dropout; else, it will pass all values through unaltered. ff_activation: Type of activation function at the end of each block; must be an activation-type subclass of `Layer`. Returns: A list of layers that maps an activation tensor to an activation tensor. """ causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode), feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) dropout_ = tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) return [ tl.Residual( tl.LayerNorm(), causal_attention, dropout_, ), tl.Residual(feed_forward), ]
def _DecoderBlock(d_model, d_ff, n_heads, d_attn_key, d_attn_value, attn_type,
                  dropout, share_qk, layer_idx, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_heads: int: number of attention heads
        d_attn_key: int: depth of key vector for each attention head
        d_attn_value: int: depth of value vector for each attention head
        attn_type: subclass of tl.BaseCausalAttention: attention class to use
        dropout: float: dropout rate (how much to drop out)
        share_qk: bool, whether to share queries and keys
        layer_idx: which layer are we at (for bookkeeping)
        mode: str: 'train' or 'eval'
        ff_activation: the non-linearity in feed-forward layer

    Returns:
        A list of layers that maps an activation tensor to an activation tensor.
    """
    causal_attention = tl.CausalAttention(
        d_model, n_heads=n_heads,
        d_attention_key=d_attn_key, d_attention_value=d_attn_value,
        attention_type=attn_type, share_qk=share_qk, mode=mode)

    dropout_ = tl.Dropout(rate=dropout, name='attention_%d' % layer_idx, mode=mode)

    feed_forward = _FeedForwardBlock(
        d_model, d_ff, dropout, layer_idx, mode, ff_activation)

    return [
        tl.Residual(
            tl.LayerNorm(),
            causal_attention,
            dropout_,
        ),
        tl.Residual(feed_forward),
    ]
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode):
    """Reversible transformer decoder layer.

    Args:
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        ff_activation: the non-linearity in feed-forward layer
        mode: str: 'train' or 'eval'

    Returns:
        the layer.
    """
    pre_attention_qkv = [
        tl.LayerNorm(),
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
    ]
    attention_qkv = tl.AttentionQKV(
        d_model, n_heads=n_heads, dropout=dropout, mode=mode)
    # TODO(kitaev): BroadcastedDropout?
    post_attention_qkv = tl.Dropout(rate=dropout, mode=mode)

    pre_causal_attention = tl.LayerNorm()
    causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)
    # TODO(kitaev): BroadcastedDropout?
    post_causal_attention = tl.Dropout(rate=dropout, mode=mode)

    feed_forward = FeedForward(d_model, d_ff, dropout, ff_activation, mode)

    return [                             # vec_d1 vec_d2 masks vec_e
        # TODO(kitaev): consider ReversibleAttentionHalfResidual for efficiency
        ReversibleHalfResidual(
            [pre_causal_attention, causal_attention, post_causal_attention]),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(
            [pre_attention_qkv, attention_qkv, post_attention_qkv]),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),             # vec_d1 vec_d2 masks vec_e
    ]
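# A toy numpy illustration (an assumption about the reversible-layer semantics,
# not code from the source) of the arithmetic that ReversibleHalfResidual plus
# tl.ReversibleSwap implement: activations are kept as two streams (x1, x2),
# one half-residual computes y1 = x1 + f(x2), the swap exchanges the streams,
# and on the backward pass x1 can be recomputed as y1 - f(x2) instead of stored.
import numpy as np

f = lambda x: 2.0 * x                  # stand-in for attention / feed-forward
x1, x2 = np.ones(3), np.full(3, 4.0)

y1, y2 = x1 + f(x2), x2                # ReversibleHalfResidual on the first stream
y1, y2 = y2, y1                        # tl.ReversibleSwap()

# Inversion: undo the swap, then subtract f to recover the original streams.
r1, r2 = y2, y1
r1 = r1 - f(r2)
assert np.allclose(r1, x1) and np.allclose(r2, x2)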
def _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation): """Returns a list of layers that implements a Transformer decoder block. The input is an activation tensor. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) dropout_shared_axes: axes on which to share dropout mask mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: A list of layers that maps an activation tensor to an activation tensor. """ causal_attention = tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode), feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) dropout_ = tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) return [ ResidualZero( tl.LayerNorm(), causal_attention, dropout_, ), ResidualZero( tl.LayerNorm(), feed_forward, dropout_, ), ]
def DecoderBlock(d_model, d_ff, n_heads, d_attention_key, d_attention_value,
                 attention_type, dropout, share_qk, layer_idx, mode):
    """Returns a layer sequence that implements a Transformer decoder block.

    The input to the layer sequence is an activation tensor.

    Args:
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_heads: int: number of attention heads
        d_attention_key: int: depth of key vector for each attention head
        d_attention_value: int: depth of value vector for each attention head
        attention_type: subclass of tl.BaseCausalAttention: attention class to use
        dropout: float: dropout rate (how much to drop out)
        share_qk: bool, whether to share queries and keys
        layer_idx: which layer are we at (for bookkeeping)
        mode: str: 'train' or 'eval'

    Returns:
        A sequence of layers that maps an activation tensor to an activation tensor.
    """
    self_attention = [
        tl.LayerNorm(),  # vec
        tl.CausalAttention(
            d_model, n_heads=n_heads,
            d_attention_key=d_attention_key, d_attention_value=d_attention_value,
            attention_type=attention_type, share_qk=share_qk, mode=mode),
        tl.Dropout(rate=dropout, name='attention_%d' % layer_idx, mode=mode),
    ]

    feed_forward = [
        FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
    ]

    return tl.Serial(
        tl.Residual(self_attention),
        tl.Residual(feed_forward),
    )
def _DecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode, ff_activation):
    """Returns a list of layers that implements a Transformer decoder block.

    The input is an activation tensor.

    Args:
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        layer_idx: which layer are we at (for bookkeeping)
        mode: str: 'train' or 'eval'
        ff_activation: the non-linearity in feed-forward layer

    Returns:
        A list of layers that maps an activation tensor to an activation tensor.
    """
    causal_attention = tl.CausalAttention(
        d_model, n_heads=n_heads, dropout=dropout, mode=mode)

    dropout_ = tl.Dropout(rate=dropout, name='attention_%d' % layer_idx, mode=mode)

    feed_forward = _FeedForwardBlock(
        d_model, d_ff, dropout, layer_idx, mode, ff_activation)

    return [
        tl.Residual(
            tl.LayerNorm(),
            causal_attention,
            dropout_,
        ),
        tl.Residual(feed_forward),
    ]
def _CausalAttention():
    # Builds a fresh causal-attention layer, closing over d_model, n_heads and
    # mode from the enclosing scope.
    return tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)
def _CausalAttention():
    # Builds a fresh causal-attention layer, closing over d_model, n_heads,
    # dropout and mode from the enclosing scope.
    return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)
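# A hedged sketch (names are illustrative, not from the original source) of how
# closures like _CausalAttention are typically used: the enclosing block builder
# supplies d_model, n_heads, dropout and mode, and each call creates a fresh
# attention layer, so weights are not shared between decoder blocks.
from trax import layers as tl

def _DecoderBlockSketch(d_model, d_ff, n_heads, dropout, mode):
    def _CausalAttention():
        return tl.CausalAttention(d_model, n_heads=n_heads, dropout=dropout, mode=mode)

    def _Dropout():
        return tl.Dropout(rate=dropout, mode=mode)

    return [
        tl.Residual(tl.LayerNorm(), _CausalAttention(), _Dropout()),
        tl.Residual(tl.LayerNorm(), tl.Dense(d_ff), tl.Relu(),
                    tl.Dense(d_model), _Dropout()),
    ]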