def EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode):
  """Returns a layer sequence that implements a Transformer encoder block.

  The input to the layer sequence is a pair, (activations, mask), where the
  mask was created from the original source tokens to prevent attending to
  the padding part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an (activations, mask) pair to an
    (activations, mask) pair.
  """
  attention = [
      tl.LayerNorm(),
      tl.Attention(d_model, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, name='enc_attn_dropout', mode=mode),
  ]
  feed_forward = [
      FeedForward(d_model, d_ff, dropout, layer_idx=layer_idx, mode=mode),
  ]
  return [
      tl.Residual(attention),
      tl.Residual(feed_forward),
  ]
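# Illustrative sketch (not part of the original source): because each
# EncoderBlock maps an (activations, mask) pair back to an (activations, mask)
# pair, blocks can simply be chained to form an encoder stack. The helper
# below is a hypothetical example of that composition.
def EncoderStack(d_model, d_ff, n_heads, dropout, n_layers, mode='train'):
  """Stacks `n_layers` encoder blocks; maps (activations, mask) to itself."""
  return [EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx=i, mode=mode)
          for i in range(n_layers)]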
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Returns a layer sequence that implements a Transformer decoder block.

  The input to the layer sequence is an activation tensor.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    A sequence of layers that maps an activation tensor to an activation
    tensor.
  """
  self_attention = [
      tl.LayerNorm(),                           # vec
      tl.Dup(),                                 # vec vec
      tl.Parallel([], tl.CausalMask(axis=-2)),  # vec mask
      tl.Attention(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # vec
      tl.Dropout(rate=dropout, mode=mode),      # vec
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
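# Illustrative sketch (not part of the original source): a decoder-only
# language model assembled from the block above. The embedding, positional
# encoding and output layers (tl.Serial, tl.ShiftRight, tl.Embedding,
# tl.PositionalEncoding, tl.Dense, tl.LogSoftmax) are assumptions about this
# version of trax.layers; argument order may differ in the installed library.
def DecoderLM(vocab_size, d_feature, d_feedforward, n_heads, dropout,
              n_layers, max_len, mode='train'):
  """Hypothetical decoder-only LM built by stacking `n_layers` decoder blocks."""
  return tl.Serial(
      tl.ShiftRight(),                      # predict token t from tokens < t
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
      [DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )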
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  decoder_self_attention = [                    #        vecs_d pmask vecs_e
      tl.LayerNorm(),                           #        vecs_d ..... ......
      tl.Dup(),                                 # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.CausalMask(axis=-2)),  # ______ masks  ..... ......
      tl.Attention(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),               # ______ 0      ..... ......
      tl.Dropout(rate=dropout, mode=mode),      #        vecs_d ..... ......
  ]
  decoder_to_encoder_attention = [      # vecs_d        masks         vecs_e
      tl.LayerNorm(),                   # vecs_d        masks         vecs_e
      tl.Parallel([], [], tl.Dup()),    # ______        _____  vecs_e vecs_e
      tl.Parallel([], tl.Swap()),       # ______        vecs_e masks  ......
      tl.Parallel([], tl.Dup()),        # ______ vecs_e vecs_e .....  ......
      tl.AttentionQKV(  # (q k v masks ... --> vecs_d masks ...)
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  # vecs_d mask vecs_e
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [                                        # vecs_d masks vecs_e
      tl.Residual(decoder_self_attention),        # vecs_d masks vecs_e
      tl.Residual(decoder_to_encoder_attention),  # vecs_d masks vecs_e
      tl.Residual(feed_forward),                  # vecs_d masks vecs_e
  ]
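# Illustrative sketch (not part of the original source): each EncoderDecoder
# layer maps the (decoder_activations, mask, encoder_activations) triple back
# to a triple of the same shape, so a full decoder tower is again just a chain
# of these layers. This hypothetical helper shows that composition.
def EncoderDecoderStack(d_feature, d_feedforward, n_heads, dropout, n_layers,
                        mode='train'):
  """Stacks `n_layers` encoder-decoder layers over the (vecs_d, masks, vecs_e) triple."""
  return [EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode)
          for _ in range(n_layers)]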
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for memory-efficient attention
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  self_attention = [
      tl.LayerNorm(),
      tl.Dup(),
      tl.Parallel([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.Attention(d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Parallel([], tl.Drop()),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]

  # TODO(kitaev): Memory-efficient attention. This chunking is temporary.
  self_attention = [
      Split(sections=n_attention_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map(self_attention),
      tl.Concatenate(axis=-2),
  ]

  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      ReversibleResidual([self_attention], [feed_forward]),
  ]
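# Illustrative sketch (not the trax implementation): the coupling that
# ReversibleResidual is assumed to compute. Writing the two activation streams
# as (x1, x2), the forward pass is y1 = x1 + F(x2), y2 = x2 + G(y1); since the
# inputs can be recomputed exactly from the outputs, intermediate activations
# need not be stored for the backward pass, which is the memory saving that
# motivates the reversible decoder block above.
import numpy as np


def reversible_forward(x1, x2, f, g):
  y1 = x1 + f(x2)
  y2 = x2 + g(y1)
  return y1, y2


def reversible_inverse(y1, y2, f, g):
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2


# Tiny self-check with arbitrary stand-ins for attention / feed-forward.
_f = np.tanh
_g = lambda x: 0.5 * x
_x1, _x2 = np.ones(4), np.arange(4.0)
assert np.allclose(
    np.stack([_x1, _x2]),
    np.stack(reversible_inverse(*reversible_forward(_x1, _x2, _f, _g), _f, _g)))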