def ChunkedDecoderLayer(d_feature,
                        d_feedforward,
                        n_heads,
                        dropout,
                        chunk_selector,
                        mode):
  """Transformer decoder layer operating on chunks.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend to.
    mode: str: 'train' or 'eval'

  Returns:
    The layers comprising a chunked decoder layer.
  """
  return [
      tl.Residual(  # Self-attention block.
          tl.Map(tl.LayerNorm()),
          ChunkedCausalMultiHeadedAttention(
              d_feature, n_heads=n_heads, dropout=dropout,
              chunk_selector=chunk_selector, mode=mode),
          tl.Map(tl.Dropout(rate=dropout, mode=mode)),
      ),
      tl.Map(ResidualFeedForward(
          d_feature, d_feedforward, dropout, mode=mode))
  ]
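# Schematically, for a single chunk x, ChunkedDecoderLayer above computes the
# standard pre-LayerNorm residual pattern; tl.Map replicates this dataflow
# across chunks. The sketch below is illustrative only (dropout omitted;
# `attend`, `feed_forward` and `layer_norm` are stand-ins for the layers
# above, not functions defined in this module).
def _decoder_layer_schematic(x, attend, feed_forward, layer_norm):
  # Residual self-attention block: normalize, attend, add the skip connection.
  x = x + attend(layer_norm(x))
  # ResidualFeedForward adds its own residual connection internally.
  return feed_forward(x)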
def ChunkedTransformerLM(vocab_size,
                         d_feature=512,
                         d_feedforward=2048,
                         n_layers=6,
                         n_heads=8,
                         dropout=0.1,
                         chunk_selector=None,
                         max_len=2048,
                         mode='train'):
  """Transformer language model operating on chunks.

  The input to this model is a sequence presented as a list or tuple of chunks:
  (chunk1, chunk2, chunk3, ..., chunkN). Each chunk should have the same shape
  (batch, chunk-length); together the chunks represent a long sequence that is
  the concatenation chunk1,chunk2,...,chunkN. The chunked Transformer emulates
  the operation of a Transformer on this long sequence, except for the chunked
  attention layer, which may attend to only a subset of the chunks in order to
  reduce memory use.

  Args:
    vocab_size: int: vocab size
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to list of chunks to attend
      to; if None, each chunk attends to the previous chunk, i.e.,
      chunk_selector(x) = [] if x < 1 else [x - 1], as in Transformer-XL.
      The current chunk is always attended to with a causal mask; the selected
      chunks are unmasked.
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    The chunked Transformer language model as a layer.
  """
  stack = [ChunkedDecoderLayer(d_feature, d_feedforward, n_heads,
                               dropout, chunk_selector, mode)
           for _ in range(n_layers)]
  # Each Map(L) below applies the layer L to every chunk independently.
  return tl.Serial(
      tl.ShiftRight(),
      tl.Map(tl.Embedding(d_feature, vocab_size)),
      tl.Map(tl.Dropout(rate=dropout, mode=mode)),
      tl.PositionalEncoding(max_len=max_len),
      tl.Serial(*stack),
      tl.Map(tl.LayerNorm()),
      tl.Map(tl.Dense(vocab_size)),
      tl.Map(tl.LogSoftmax()),
  )
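# A minimal usage sketch, assuming a trax version whose layer APIs match the
# code above. `_attend_to_all_previous` is a hypothetical selector and the
# sizes are example values only, not defaults from this module.
def _chunked_lm_usage_example():
  import numpy as np

  def _attend_to_all_previous(chunk_number):
    # Chunk i attends to all earlier chunks; costlier than the default
    # Transformer-XL-style previous-chunk selector.
    return list(range(chunk_number))

  model = ChunkedTransformerLM(
      vocab_size=1000, d_feature=128, d_feedforward=512, n_layers=2,
      n_heads=4, chunk_selector=_attend_to_all_previous, mode='train')
  # A length-512 sequence presented as 4 chunks of shape (batch, 128).
  chunks = tuple(np.zeros((8, 128), dtype=np.int32) for _ in range(4))
  return model, chunks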
def ChunkedCausalMultiHeadedAttention(
    d_feature, n_heads=8, dropout=0.0, chunk_selector=None, mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    d_feature: int: depth of embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend to.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = tl.Serial(
      tl.Branch(
          tl.Branch(  # q = k = v = first input
              tl.NoOp(), tl.NoOp(), tl.NoOp()),
          tl.CausalMask(axis=-2),
      ),
      tl.Parallel(
          tl.Parallel(  # Dense projections for q, k and v.
              tl.Dense(d_feature),
              tl.Dense(d_feature),
              tl.Dense(d_feature),
          ),
          tl.NoOp()))  # The mask passes through unchanged.
  return tl.Serial(
      tl.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      tl.Map(tl.PureMultiHeadedAttention(
          d_feature=d_feature, n_heads=n_heads,
          dropout=dropout, mode=mode), check_shapes=False),
      tl.Map(tl.Select(0), check_shapes=False),  # drop masks
      tl.Map(tl.Dense(d_feature)))
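# ChunkedAttentionSelector is referenced above but not shown here. To make the
# masking described in ChunkedTransformerLM's docstring concrete, the NumPy
# sketch below (an illustration, not this module's code) builds the mask one
# chunk sees when attending to `n_selected` fully visible chunks plus itself
# with a causal mask.
def _combined_chunk_mask_sketch(chunk_len, n_selected):
  import numpy as np
  # Keys span the selected chunks followed by the current chunk; queries come
  # from the current chunk only. Selected chunks are fully visible.
  selected_part = np.ones((chunk_len, n_selected * chunk_len), dtype=bool)
  # Within the current chunk, position i may only attend to positions <= i.
  causal_part = np.tril(np.ones((chunk_len, chunk_len), dtype=bool))
  return np.concatenate([selected_part, causal_part], axis=-1)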