def Transformer(vocab_size,
                d_feature=512,
                d_feedforward=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects as input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  positional_embedder = [
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]
  encoder = [
      tl.Branch(positional_embedder, tl.PaddingMask()),
      [EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.LayerNorm(),
  ]
  return tl.Model(
      tl.Parallel([], tl.ShiftRight()),
      tl.Parallel(encoder, positional_embedder),
      tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                output=('decoder', ('mask', 'decoder'), 'encoder')),
      # (encoder_mask, decoder_input) -> encoder-decoder mask
      tl.Parallel([], tl.EncoderDecoderMask(), []),
      [EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.Select(0),  # Drop mask and encoder.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
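
# Usage sketch (not from the source): instantiating the model above, assuming
# `tl` (the layers module), EncoderBlock and EncoderDecoder are importable from
# the surrounding package. The hyperparameter values are illustrative only;
# feeding (source, target) batches and training go through the framework's
# trainer and are not shown here.
translation_model = Transformer(vocab_size=33000, d_feature=512, n_layers=6,
                                n_heads=8, mode='train')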
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Transformer.

  This model expects as input a pair (source, target).

  Args:
    vocab_size: int: vocab size (shared source and target).
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer model.
  """
  embedding = tl.Serial(
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len)
  )
  encoder = tl.Serial(
      tl.Branch(embedding, tl.PaddingMask()),
      tl.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.Parallel(tl.LayerNorm(), tl.Copy())
  )
  stack = [EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
           for _ in range(num_layers)]
  return tl.Serial(
      tl.Parallel(tl.Copy(), tl.ShiftRight()),
      tl.Parallel(encoder, embedding),
      tl.UnnestBranches(),  # (encoder, encoder_mask, decoder_input)
      tl.Select((0, (1, 2), 2)),
      tl.Parallel(
          # (encoder_mask, decoder_input) -> encoder-decoder mask
          tl.Copy(), tl.EncoderDecoderMask(), tl.Copy()),
      tl.Serial(*stack),
      tl.Select(2),  # Drop encoder and mask.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
  """Transformer encoder-decoder layer.

  The input is a triple (encoder, mask, decoder_input); the mask is created
  from the original source to prevent attending to the padding part of the
  encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (encoder, mask, decoder_activations).
  """
  # Decoder self-attending to decoder.
  self_attention = tl.Residual(
      tl.LayerNorm(),
      tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = tl.Serial(
      tl.Select(((2, 0, 0), 1)),  # ((dec, enc, enc), mask)
      tl.MultiHeadedAttentionQKV(  # ((q, k, v), mask) --> new, mask
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Select(0),  # Drop the mask.
      tl.Dropout(rate=dropout, mode=mode),
  )
  return tl.Serial(
      tl.Parallel(tl.Copy(), tl.Copy(), self_attention),
      tl.Branch(tl.Copy(), encoder_decoder_attention),
      tl.UnnestBranches(),  # (encoder, mask, old_act, new_act)
      tl.Select((0, 1, (2, 3))),
      tl.Parallel(  # Residual after encoder-decoder attention.
          tl.Copy(), tl.Copy(), tl.Add()),
      tl.Parallel(  # Feed-forward on the third component (decoder).
          tl.Copy(), tl.Copy(),
          ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                              mode=mode)))
def MultiHeadedAttentionPosition(positions, d_feature, n_heads=8, dropout=0.0,
                                 mode='train'):
  """Transformer-style multi-headed attention."""
  return tl.Serial(
      tl.Dup(),
      tl.Dup(),
      tl.Parallel(
          ApplyAndQueryPositions(
              tl.Dense(d_feature),
              pos=[SumLearnedPick(positions) for _ in range(n_heads)]),
          PreservePosition(tl.Dense(d_feature)),
          PreservePosition(tl.Dense(d_feature)),
      ),
      tl.Parallel(
          CopyHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
      ),
      tl.PureMultiHeadedAttention(
          d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Select(0),  # Drop the mask.
      CombineHeadsPos(h=n_heads),
      PreservePosition(tl.Dense(d_feature)),
  )
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.LayerNorm(),
          tl.Branch(tl.Copy(), tl.CausalMask(axis=-2)),  # Create mask.
          tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                  dropout=dropout, mode=mode),
          tl.Select(0),  # Drop the mask.
          tl.Dropout(rate=dropout, mode=mode)),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode))
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  self_attention = [
      tl.LayerNorm(),
      tl.Branch([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(self_attention),
      tl.Residual(feed_forward),
  ]
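
# Sketch (an assumption, not code from the source): DecoderBlock is the unit a
# decoder-only language model would stack. Combined with the embedding and
# output layers used elsewhere in this file, a minimal hypothetical stack could
# look roughly like this:
def DecoderStackSketch(vocab_size, d_feature=512, d_feedforward=2048,
                       n_layers=6, n_heads=8, dropout=0.1, max_len=2048,
                       mode='train'):
  """Hypothetical decoder-only stack built from DecoderBlock."""
  return [
      tl.ShiftRight(),
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
      [DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ]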
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial([
      # TODO(jonni): Rewrite without using Select.
      tl.Select(inputs=('x1_or_y1', 'x2'), output=('x2', 'x1_or_y1', 'x2')),
      tl.Parallel(residual_layers, [], []),
  ])
  layers = [self.compute_residual, tl.Add()]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.SubtractTop()
  self.reverse_layers = [self.compute_residual, self.subtract_top]
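
# Illustration (not from the source): the reversible half-residual computes
# y1 = x1 + F(x2) on the forward pass (compute_residual followed by Add) and
# recovers x1 = y1 - F(x2) on the reverse pass (compute_residual followed by
# SubtractTop), so activations need not be stored. A plain-numpy sketch with a
# hypothetical residual function `f`:
import numpy as np

def reversible_half_residual_forward(x1, x2, f):
  return x1 + f(x2), x2  # (y1, y2); x2 passes through unchanged.

def reversible_half_residual_reverse(y1, y2, f):
  return y1 - f(y2), y2  # Recompute f(x2) == f(y2) and subtract to recover x1.

x1, x2 = np.ones(4), np.arange(4.0)
y1, y2 = reversible_half_residual_forward(x1, x2, np.tanh)
assert np.allclose(reversible_half_residual_reverse(y1, y2, np.tanh)[0], x1)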
def EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder); the mask is created
  from the original source to prevent attending to the padding part of the
  encoder.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  decoder_self_attention = [
      # TODO(jonni): Work on combinators so that this flow is cleaner/clearer.
      tl.LayerNorm(),
      tl.Dup(),
      tl.CausalMask(axis=-2),  # Create the self-attention mask.
      tl.Swap(),  # Put mask behind the activations.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Swap(),  # Put self-attention mask on top.
      tl.Drop(),  # Drop self-attention mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  decoder_to_encoder_attention = [
      tl.Select((0, 2, 2, 1, 2)),  # (dec, enc, enc, mask, enc-copy)
      tl.MultiHeadedAttentionQKV(  # (q, k, v, mask, ...) --> (new, mask, ...)
          d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      tl.Residual(decoder_self_attention),
      tl.Residual(decoder_to_encoder_attention),
      tl.Residual(feed_forward),
  ]
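
# Illustration (not from the source): the Select((0, 2, 2, 1, 2)) above turns
# the stack (decoder_activations, mask, encoder) into (q, k, v, mask, enc-copy)
# for MultiHeadedAttentionQKV, i.e. queries come from the decoder and
# keys/values from the encoder, with a copy of the encoder kept for later
# layers. The index semantics can be checked on a plain tuple:
stack = ('dec', 'mask', 'enc')
assert tuple(stack[i] for i in (0, 2, 2, 1, 2)) == (
    'dec', 'enc', 'enc', 'mask', 'enc')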
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder); the mask is created
  from the original source to prevent attending to the padding part of the
  encoder.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  # Decoder self-attending to decoder.
  self_attention = tl.Residual(
      tl.LayerNorm(),
      tl.Dup(),
      tl.CausalMask(axis=-2),  # Create the self-attention mask.
      tl.Swap(),  # Put mask behind the activations.
      tl.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                              dropout=dropout, mode=mode),
      tl.Swap(),  # Put self-attention mask on top.
      tl.Drop(),  # Drop self-attention mask.
      tl.Dropout(rate=dropout, mode=mode))
  # Decoder attending to encoder.
  encoder_decoder_attention = tl.Serial(
      tl.Select((0, 2, 2, 1, 2)),  # (dec, enc, enc, mask, enc-copy)
      tl.MultiHeadedAttentionQKV(  # (q, k, v, mask, ...) --> (new, mask, ...)
          feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),
  )
  return tl.Serial(
      self_attention,
      tl.Residual(encoder_decoder_attention),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout,
                          mode=mode))
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_feature=512,
                       d_feedforward=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    n_classes: how many classes on output
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
  positional_embedder = [
      tl.Embedding(d_feature, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]
  return [
      tl.Branch(positional_embedder, tl.PaddingMask()),  # Create mask.
      [EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.Select(0),  # Drop mask.
      tl.LayerNorm(),
      tl.Mean(axis=1),  # Average on length.
      tl.Dense(n_classes),
      tl.LogSoftmax(),
  ]
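
# Usage sketch (not from the source): building the encoder as a sequence
# classifier, assuming EncoderBlock and `tl` are importable from the
# surrounding package. The values below are illustrative only.
classifier = TransformerEncoder(vocab_size=32000, n_classes=10, n_layers=2,
                                mode='train')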
def ChunkedCausalMultiHeadedAttention(
    d_feature, n_heads=8, dropout=0.0, chunk_selector=None, mode='train'):
  """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    d_feature: int: depth of embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to list of chunks to attend.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
  prepare_attention_input = tl.Serial(
      tl.Branch(
          tl.Branch(  # q = k = v = first input
              tl.NoOp(), tl.NoOp(), tl.NoOp()),
          tl.CausalMask(axis=-2),
      ),
      tl.Parallel(
          tl.Parallel(
              tl.Dense(d_feature),
              tl.Dense(d_feature),
              tl.Dense(d_feature),
          ),
          tl.NoOp()))
  return tl.Serial(
      tl.Map(prepare_attention_input),
      ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
      tl.Map(tl.PureMultiHeadedAttention(
          d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode),
             check_shapes=False),
      tl.Map(tl.Select(0), check_shapes=False),  # Drop masks.
      tl.Map(tl.Dense(d_feature)))
def TransformerEncoder(vocab_size,
                       num_classes=10,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_layers=6,
                       num_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train'):
  """Transformer encoder.

  Args:
    vocab_size: int: vocab size
    num_classes: how many classes on output
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the Transformer encoder layer.
  """
  input_embedding = tl.Serial(
      tl.Embedding(feature_depth, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len)
  )
  return tl.Serial(
      tl.Branch(input_embedding, tl.PaddingMask()),
      tl.Serial(*[EncoderLayer(feature_depth, feedforward_depth, num_heads,
                               dropout, mode)
                  for _ in range(num_layers)]),
      tl.Select(0),  # Drop the mask.
      tl.LayerNorm(),
      tl.Mean(axis=1),  # Average on length.
      tl.Dense(num_classes),
      tl.LogSoftmax()
  )
def DecoderBlock(d_feature, d_feedforward, n_heads, n_attention_chunks,
                 dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_feature: int: depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for memory-efficient attention
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  self_attention = [
      tl.LayerNorm(),
      tl.Branch([], tl.CausalMask(axis=-2)),  # Create mask.
      tl.MultiHeadedAttention(d_feature, n_heads=n_heads,
                              dropout=dropout, mode=mode),
      tl.Select(0),  # Drop mask.
      tl.Dropout(rate=dropout, mode=mode),
  ]
  # TODO(kitaev): Memory-efficient attention. This chunking is temporary.
  self_attention = [
      Split(sections=n_attention_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      Map(self_attention),
      tl.Concatenate(axis=-2),
  ]
  feed_forward = [
      FeedForward(d_feature, d_feedforward, dropout, mode=mode),
  ]
  return [
      ReversibleResidual([self_attention], [feed_forward]),
  ]
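
# Illustration (not from the source): the Split/Map/Concatenate wrapper above
# applies self-attention to each chunk of the length axis independently and
# then re-joins the chunks, bounding per-chunk attention memory. A plain-numpy
# sketch of that data flow, with a stand-in function in place of attention:
import numpy as np

def chunked_apply(x, fn, n_chunks, axis=-2):
  chunks = np.split(x, n_chunks, axis=axis)  # Split(sections=..., axis=-2)
  outputs = [fn(c) for c in chunks]          # Map(layer)
  return np.concatenate(outputs, axis=axis)  # Concatenate(axis=-2)

x = np.random.rand(2, 8, 16)                 # (batch, length, features)
y = chunked_apply(x, lambda c: c * 2.0, n_chunks=4)
assert y.shape == x.shape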
def ApplyAndQueryPositions(layer, pos):
  """Execute layer without position and pos-layers on positions.

  This takes an embedding including position, x = (emb, p), and outputs the
  concatenation of layer(emb), pos_1(x, p), ..., pos_n(x, p), where
  pos = [pos_1, ..., pos_n].

  Args:
    layer: layer to be executed without position information.
    pos: list of layers to be applied to positions.

  Returns:
    the result of this application.
  """
  n_heads = len(pos)
  return tl.Serial(
      tl.Dup(),
      CutPosition(),
      # TODO(lukaszkaiser): Rewrite without using Select.
      tl.Select(tuple([0] + [(2, 1)] * n_heads)),
      tl.Parallel(*([layer] + pos)),
      Unnest(),
      ConcatenateN(n=n_heads + 1))