def EncoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Builds one Transformer encoder layer.

    The layer consumes a pair (embedded source, mask). The mask is created from
    the original source to prevent attending to the padding part of the input.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a pair (activations, mask).
    """
    # Attention block. The tl.NoOp() in each tl.Parallel occupies the second
    # (mask) slot of the pair.
    normalize = tl.Parallel(tl.LayerNorm(), tl.NoOp())
    attend = tl.MultiHeadedAttention(
        feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
    regularize = tl.Parallel(tl.Dropout(rate=dropout, mode=mode), tl.NoOp())
    attention_block = tl.Residual(normalize, attend, regularize)

    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)
    # Mask added to itself in the residual, divide.
    halve_mask = tl.Div(divisor=2.0)

    return tl.Serial(attention_block, tl.Parallel(feed_forward, halve_mask))
def Transformer(vocab_size,
                d_feature=512,
                d_feedforward=2048,
                n_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Assembles the full encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_layers: int: number of encoder/decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model, as a list of layers.
    """
    # The very same list object is used for the source (inside the encoder)
    # and for the shifted target below.
    positional_embedder = [
        tl.Embedding(d_feature, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    embed_and_mask = tl.Branch(positional_embedder, tl.PaddingMask())
    encoder_blocks = []
    for _ in range(n_layers):
        encoder_blocks.append(
            EncoderBlock(d_feature, d_feedforward, n_heads, dropout, mode))
    encoder = [embed_and_mask, encoder_blocks, tl.LayerNorm()]

    shift_target = tl.Parallel(tl.NoOp(), tl.ShiftRight())
    embed_both = tl.Parallel(encoder, positional_embedder)
    regroup = tl.Select(
        inputs=(('encoder', 'mask'), 'decoder'),
        output=('decoder', ('mask', 'decoder'), 'encoder'))
    # (encoder_mask, decoder_input) -> encoder-decoder mask
    build_mask = tl.Parallel(tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp())

    decoder_blocks = []
    for _ in range(n_layers):
        decoder_blocks.append(
            EncoderDecoder(d_feature, d_feedforward, n_heads, dropout, mode))

    return [
        shift_target,
        embed_both,
        regroup,
        build_mask,
        decoder_blocks,
        tl.Select(0),  # Drop mask and encoder.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
def Transformer(vocab_size,
                feature_depth=512,
                feedforward_depth=2048,
                num_layers=6,
                num_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Assembles the full encoder-decoder Transformer.

    This model expects on input a pair (source, target).

    Args:
      vocab_size: int: vocab size (shared source and target).
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_layers: int: number of encoder/decoder layers
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      the Transformer model.
    """
    # One embedding layer object, used both inside the encoder and on the
    # shifted target.
    embedding = tl.Serial(
        tl.Embedding(feature_depth, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len))

    embed_with_mask = tl.Branch(embedding, tl.PaddingMask())
    encoder_layers = []
    for _ in range(num_layers):
        encoder_layers.append(
            EncoderLayer(feature_depth, feedforward_depth, num_heads,
                         dropout, mode))
    encoder = tl.Serial(
        embed_with_mask,
        tl.Serial(*encoder_layers),
        tl.Parallel(tl.LayerNorm(), tl.NoOp()))

    decoder_layers = []
    for _ in range(num_layers):
        decoder_layers.append(
            EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads,
                                dropout, mode))

    pipeline = (
        tl.Parallel(tl.NoOp(), tl.ShiftRight()),
        tl.Parallel(encoder, embedding),
        tl.Select(inputs=(('encoder', 'mask'), 'decoder'),
                  output=('encoder', ('mask', 'decoder'), 'decoder')),
        # (encoder_mask, decoder_input) -> encoder-decoder mask
        tl.Parallel(tl.NoOp(), tl.EncoderDecoderMask(), tl.NoOp()),
        tl.Serial(*decoder_layers),
        tl.Select(2),  # Drop encoder and mask.
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
    return tl.Serial(*pipeline)
def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
    """Builds one Transformer decoder layer.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    # Self-attention block.
    attention_steps = (
        tl.LayerNorm(),
        tl.Branch(tl.NoOp(), tl.CausalMask(axis=-2)),  # Create mask.
        tl.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Select(0),  # Drop the mask.
        tl.Dropout(rate=dropout, mode=mode),
    )
    self_attention = tl.Residual(*attention_steps)

    feed_forward = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    return tl.Serial(self_attention, feed_forward)
def DecoderBlock(d_feature, d_feedforward, n_heads, dropout, mode):
    """Builds one Transformer decoder block.

    Args:
      d_feature: int: depth of embedding
      d_feedforward: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, as a list of two residual sub-blocks (self-attention, then
      feed-forward).
    """
    normalize = tl.LayerNorm()
    attach_causal_mask = tl.Branch(tl.NoOp(), tl.CausalMask(axis=-2))
    attend = tl.MultiHeadedAttention(
        d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
    discard_mask = tl.Select(0)
    regularize = tl.Dropout(rate=dropout, mode=mode)
    attention_steps = [
        normalize,
        attach_causal_mask,  # Create mask.
        attend,
        discard_mask,  # Drop mask.
        regularize,
    ]

    feed_forward_steps = [
        FeedForward(d_feature, d_feedforward, dropout, mode=mode),
    ]

    # Each sub-block list is handed to tl.Residual as a single argument.
    return [tl.Residual(attention_steps), tl.Residual(feed_forward_steps)]
def AtariCnn(hidden_sizes=(32, 32), output_size=128):
    """Builds a small convolutional network over stacked input frames.

    Input's shape = (B, T, H, W, C); the result is shaped (B, T, output_size).

    Args:
      hidden_sizes: sequence whose first two entries are passed as the first
        argument of the first and second tl.Conv layers respectively.
      output_size: value passed to the final tl.Dense layer.

    Returns:
      the model, as a tl.Serial layer.
    """

    def shifted(times):
        # `times` consecutive tl.ShiftRight() layers chained in a tl.Serial.
        return tl.Serial(*[tl.ShiftRight() for _ in range(times)])

    kernel_shape = (5, 5)
    stride = (2, 2)

    stages = [
        tl.Div(divisor=255.0),
        # Have 4 copies of the input, each one shifted to the right by one.
        tl.Branch(tl.NoOp(), tl.ShiftRight(), shifted(2), shifted(3)),
        # Concatenated on the last axis.
        tl.Concatenate(axis=-1),  # (B, T, H, W, 4C)
        tl.Rebatch(tl.Conv(hidden_sizes[0], kernel_shape, stride, 'SAME'), 2),
        tl.Relu(),
        tl.Rebatch(tl.Conv(hidden_sizes[1], kernel_shape, stride, 'SAME'), 2),
        tl.Relu(),
        tl.Flatten(num_axis_to_keep=2),  # B, T and rest.
        tl.Dense(output_size),
        tl.Relu(),
        # Eventually this is shaped (B, T, output_size)
    ]
    return tl.Serial(*stages)
def EncoderDecoderLayer(feature_depth, feedforward_depth, num_heads, dropout,
                        mode):
    """Builds one Transformer encoder-decoder layer.

    The input is a triple (encoder, mask, decoder_input) where the mask is
    created from the original source to prevent attending to the padding part
    of the encoder.

    Args:
      feature_depth: int: depth of embedding
      feedforward_depth: int: depth of feed-forward layer
      num_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      mode: str: 'train' or 'eval'

    Returns:
      the layer, returning a triple (encoder, mask, decoder_activations).
    """
    # Decoder self-attending to decoder.
    self_attention_steps = (
        tl.LayerNorm(),
        tl.Branch(tl.NoOp(), tl.CausalMask(axis=-2)),  # create mask
        tl.MultiHeadedAttention(
            feature_depth, num_heads=num_heads, dropout=dropout, mode=mode),
        tl.Select(0),  # drop mask
        tl.Dropout(rate=dropout, mode=mode),
    )
    self_attention = tl.Residual(*self_attention_steps)

    # Decoder attending to encoder.
    pick_qkv_mask = tl.Select((2, 0, 0, 1))  # (dec, enc, enc, mask)
    cross_attend = tl.MultiHeadedAttentionQKV(  # (q, k, v, mask) --> new, mask
        feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)
    encoder_decoder_attention = tl.Serial(
        pick_qkv_mask,
        cross_attend,
        tl.Select(0),  # drop the mask
        tl.Dropout(rate=dropout, mode=mode),
    )

    apply_self_attention = tl.Parallel(tl.NoOp(), tl.NoOp(), self_attention)
    add_cross_attention = tl.Branch(tl.NoOp(), encoder_decoder_attention)
    pair_activations = tl.Select(
        inputs=(('encoder', 'mask', 'old_act'), 'new_act'),
        output=('encoder', 'mask', ('old_act', 'new_act')))
    # Residual after encoder-decoder attention.
    sum_activations = tl.Parallel(tl.NoOp(), tl.NoOp(), tl.Add())
    # Feed-forward on the third component (decoder).
    feed_forward = tl.Parallel(
        tl.NoOp(),
        tl.NoOp(),
        ResidualFeedForward(
            feature_depth, feedforward_depth, dropout, mode=mode))

    return tl.Serial(
        apply_self_attention,
        add_cross_attention,
        pair_activations,
        sum_activations,
        feed_forward)
def ChunkedCausalMultiHeadedAttention(d_feature,
                                      n_heads=8,
                                      dropout=0.0,
                                      chunk_selector=None,
                                      mode='train'):
    """Transformer-style causal multi-headed attention operating on chunks.

    Accepts inputs that are a list of chunks and applies causal attention.

    Args:
      d_feature: int: depth of embedding
      n_heads: int: number of attention heads
      dropout: float: dropout rate
      chunk_selector: a function from chunk number to list of chunks to attend.
      mode: str: 'train' or 'eval'

    Returns:
      Multi-headed self-attention layer.
    """
    # q = k = v = first input
    triplicate = tl.Branch(*[tl.NoOp() for _ in range(3)])
    qkv_with_mask = tl.Branch(triplicate, tl.CausalMask(axis=-2))
    # A separate tl.Dense layer for each of the three branches above.
    qkv_projections = tl.Parallel(*[tl.Dense(d_feature) for _ in range(3)])
    project_keep_mask = tl.Parallel(qkv_projections, tl.NoOp())
    prepare_attention_input = tl.Serial(qkv_with_mask, project_keep_mask)

    prepare_each_chunk = tl.Map(prepare_attention_input)
    select_chunks = ChunkedAttentionSelector(selector=chunk_selector)  # pylint: disable=no-value-for-parameter
    attention = tl.PureMultiHeadedAttention(
        d_feature=d_feature, n_heads=n_heads, dropout=dropout, mode=mode)
    attend_each_chunk = tl.Map(attention, check_shapes=False)
    drop_masks = tl.Map(tl.Select(0), check_shapes=False)  # drop masks
    project_output = tl.Map(tl.Dense(d_feature))

    return tl.Serial(
        prepare_each_chunk,
        select_chunks,
        attend_each_chunk,
        drop_masks,
        project_output)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """Builds a WideResnet convolutional block.

    Args:
      channels: value passed as the first argument of every tl.Conv in the
        block.
      strides: strides given to the first convolution of the main path and to
        the shortcut convolution, when there is one.
      channel_mismatch: if true, the shortcut is a 3x3 convolution; otherwise
        it is a tl.NoOp().

    Returns:
      the block, as a tl.Residual layer.
    """
    kernel = (3, 3)

    main_path = tl.Serial(
        tl.BatchNorm(),
        tl.Relu(),
        tl.Conv(channels, kernel, strides, padding='SAME'),
        tl.BatchNorm(),
        tl.Relu(),
        tl.Conv(channels, kernel, padding='SAME'),
    )

    if channel_mismatch:
        skip_path = tl.Conv(channels, kernel, strides, padding='SAME')
    else:
        skip_path = tl.NoOp()

    return tl.Residual(main_path, shortcut=skip_path)
def Residual(*layers, **unused_kwargs):
    """Constructs a residual version of layers, summing input to layers output.

    Args:
      *layers: layers chained with tl.Serial to form the residual branch.
      **unused_kwargs: accepted but ignored.

    Returns:
      a tl.Serial that branches into (serial layers, tl.NoOp()) and then
      applies tl.AddAll().
    """
    transformed = tl.Serial(*layers)
    fork = tl.Branch(transformed, tl.NoOp())
    return tl.Serial(fork, tl.AddAll())