def encoder(embedded_source, source_mask):
  """Transformer encoder stack.

  Args:
    embedded_source: staxlayer variable: embedded source sequences
    source_mask: staxlayer variable: self-attention mask

  Returns:
    Staxlayer variable that outputs encoded source.
  """
  # Note: multi_attention, feed_forward, keep_rate, mode and num_layers are
  # closed over from an enclosing model constructor (cf. TransformerLM at the
  # end of this section, which builds the same sub-layers inline).
  encoder_layer = stax.serial(
      # input attends to self
      stax.residual(stax.LayerNorm(),
                    stax.FanOut(4),
                    stax.parallel(stax.Identity,  # query
                                  stax.Identity,  # key
                                  stax.Identity,  # value
                                  source_mask),   # attention mask
                    multi_attention,
                    stax.Dropout(keep_rate, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(),
                    feed_forward,
                    stax.Dropout(keep_rate, mode=mode))
  )
  return stax.serial(
      embedded_source,
      stax.repeat(encoder_layer, num_layers),
      stax.LayerNorm(),
  )

def DecoderLayer(feature_depth, feedforward_depth, num_heads, dropout, mode):
  """Transformer decoder layer.

  Args:
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  return stax.serial(
      stax.residual(  # Self-attention block.
          stax.LayerNorm(),
          stax.FanOut(4),
          stax.parallel(stax.Identity,  # query
                        stax.Identity,  # key
                        stax.Identity,  # value
                        stax.CausalMask(axis=-2)),  # attention mask
          stax.MultiHeadedAttention(feature_depth, num_heads=num_heads,
                                    dropout=dropout, mode=mode),
          stax.Dropout(dropout, mode=mode)
      ),
      ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode=mode)
  )

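# `ResidualFeedForward` is used above (and in `Encoder` below) but never
# defined in these snippets. The following is a minimal sketch of what it
# plausibly contains, assuming it wraps the same feed-forward sub-network that
# `TransformerLM` at the end of this section spells out inline (LayerNorm
# first, then Dense-Relu-Dropout-Dense inside a residual connection); the
# exact dropout convention is an assumption.
def ResidualFeedForward(feature_depth, feedforward_depth, dropout, mode):
  """Sketch: residual feed-forward block, LayerNorm then Dense-Relu-Dense."""
  return stax.residual(
      stax.LayerNorm(),
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(dropout, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform()),
      stax.Dropout(dropout, mode=mode))
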
def Encoder(source, source_mask):
  """Transformer encoder stack.

  Args:
    source: staxlayer variable: raw source sequences
    source_mask: staxlayer variable: self-attention mask

  Returns:
    Staxlayer variable that outputs encoded source.
  """
  encoder_layer = stax.serial(
      # input attends to self
      stax.residual(stax.LayerNorm(),
                    stax.FanOut(4),
                    stax.parallel(stax.Identity,  # query
                                  stax.Identity,  # key
                                  stax.Identity,  # value
                                  source_mask),   # attention mask
                    multi_attention,
                    stax.Dropout(dropout, mode=mode)),
      # feed-forward
      ResidualFeedForward(
          feature_depth, feedforward_depth, dropout, mode=mode),
  )
  return stax.serial(
      source,
      source_embedding_layer,
      stax.repeat(encoder_layer, num_layers),
      stax.LayerNorm(),
  )

def decoder(memory, target, target_mask, memory_mask):
  """Transformer decoder stack.

  Args:
    memory: staxlayer variable: encoded source sequences
    target: staxlayer variable: raw target sequences
    target_mask: staxlayer variable: self-attention mask
    memory_mask: staxlayer variable: memory attention mask

  Returns:
    Staxlayer variable that outputs decoded target.
  """
  decoder_layer = stax.serial(
      # target attends to self
      stax.residual(
          stax.LayerNorm(),
          stax.FanOut(4),
          stax.parallel(
              stax.Identity,  # query
              stax.Identity,  # key
              stax.Identity,  # value
              target_mask),   # attention mask
          multi_attention,
          stax.Dropout(keep_rate, mode=mode)),
      # target attends to encoded source
      stax.residual(
          stax.LayerNorm(),
          stax.FanOut(4),
          stax.parallel(
              stax.Identity,  # query
              memory,         # key
              memory,         # value
              memory_mask),   # attention mask
          multi_attention,
          stax.Dropout(keep_rate, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(),
                    feed_forward,
                    stax.Dropout(keep_rate, mode=mode)))
  return stax.serial(
      target,
      target_embedding_layer,
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(),
  )

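# How the two stacks would be wired together into a full encoder-decoder
# Transformer (hypothetical glue code, not from the original source; the
# embedding layers, masks, and closure variables used above are assumed to be
# built by an enclosing model constructor):
def transformer_body(source, target, source_mask, target_mask, memory_mask):
  # The encoded source becomes the `memory` stream the decoder attends to.
  memory = Encoder(source, source_mask)
  return decoder(memory, target, target_mask, memory_mask)
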
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """WideResnet convolutional block."""
  main = stax.Serial(stax.BatchNorm(), stax.Relu(),
                     stax.Conv(channels, (3, 3), strides, padding='SAME'),
                     stax.BatchNorm(), stax.Relu(),
                     stax.Conv(channels, (3, 3), padding='SAME'))
  shortcut = stax.Identity() if not channel_mismatch else stax.Conv(
      channels, (3, 3), strides, padding='SAME')
  return stax.Serial(stax.FanOut(), stax.Parallel(main, shortcut),
                     stax.FanInSum())

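# A sketch of how such blocks are typically grouped in a WideResnet: the first
# block in a group handles the stride and channel change, the rest keep shapes
# fixed (illustrative helper, not part of the original code):
def WideResnetGroup(n, channels, strides=(1, 1)):
  blocks = [WideResnetBlock(channels, strides, channel_mismatch=True)]
  blocks += [WideResnetBlock(channels) for _ in range(n - 1)]
  return stax.Serial(*blocks)
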
def IdentityBlock(kernel_size, filters):
  """ResNet identity block: output shape matches the input shape."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = stax.Serial(stax.Conv(filters1, (1, 1)),
                     stax.BatchNorm(), stax.Relu(),
                     stax.Conv(filters2, (ks, ks), padding='SAME'),
                     stax.BatchNorm(), stax.Relu(),
                     stax.Conv(filters3, (1, 1)), stax.BatchNorm())
  return stax.Serial(stax.FanOut(), stax.Parallel(main, stax.Identity()),
                     stax.FanInSum(), stax.Relu())

def ConvBlock(kernel_size, filters, strides):
  """ResNet convolutional striding block."""
  ks = kernel_size
  filters1, filters2, filters3 = filters
  main = stax.Serial(stax.Conv(filters1, (1, 1), strides),
                     stax.BatchNorm(), stax.Relu(),
                     stax.Conv(filters2, (ks, ks), padding='SAME'),
                     stax.BatchNorm(), stax.Relu(),
                     stax.Conv(filters3, (1, 1)), stax.BatchNorm())
  shortcut = stax.Serial(stax.Conv(filters3, (1, 1), strides),
                         stax.BatchNorm())
  return stax.Serial(stax.FanOut(), stax.Parallel(main, shortcut),
                     stax.FanInSum(), stax.Relu())

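# Composing the two block types above (the three-filter IdentityBlock and
# ConvBlock) into one bottleneck ResNet stage; the filter counts in the usage
# note are illustrative, not from the original code:
def ResnetStage(filters, strides):
  return stax.Serial(ConvBlock(3, filters, strides),
                     IdentityBlock(3, filters),
                     IdentityBlock(3, filters))

# e.g. ResnetStage((64, 64, 256), strides=(1, 1)) for a ResNet-50-style stage.
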
def IdentityBlock(kernel_size, filters):
  """ResNet identity block: output shape matches the input shape."""
  ks = kernel_size
  filters1, filters2 = filters
  def MakeMain(input_shape):
    # the number of output channels depends on the number of input channels
    return stax.serial(stax.Conv(filters1, (1, 1)),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(filters2, (ks, ks), padding='SAME'),
                       stax.BatchNorm(), stax.Relu,
                       stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())
  main = stax.shape_dependent(MakeMain)
  return stax.serial(stax.FanOut(2), stax.parallel(main, stax.Identity),
                     stax.FanInSum, stax.Relu)

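# Usage sketch for the shape-dependent variant above, assuming `stax` in this
# snippet is jax.experimental.stax (whose lowercase combinators and
# (output_shape, params) init convention it matches); shapes are illustrative:
from jax import random
from jax.experimental import stax

init_fun, apply_fun = IdentityBlock(3, (64, 64))
out_shape, params = init_fun(random.PRNGKey(0), (8, 32, 32, 256))  # NHWC
# y = apply_fun(params, x)  # x: float array of shape (8, 32, 32, 256)
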
def policy_and_value_net(rng_key, batch_observations_shape, num_actions,
                         bottom_layers=None):
  """A policy and value net function."""
  # Layers.
  layers = []
  if bottom_layers is not None:
    layers.extend(bottom_layers)

  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  layers.extend([stax.FanOut(), stax.Parallel(
      stax.Serial(stax.Dense(num_actions), stax.Softmax()),
      stax.Dense(1)
  )])

  net_init, net_apply = stax.Serial(layers)
  _, net_params = net_init(rng_key, batch_observations_shape)
  return net_params, net_apply

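# Usage sketch (hypothetical CartPole-like shapes: 4-dimensional observations,
# 2 actions; `bottom_layers` left empty so both heads act directly on the
# observations):
from jax import random

params, apply_fun = policy_and_value_net(
    random.PRNGKey(0), batch_observations_shape=(-1, 4), num_actions=2)
# probs, value = apply_fun(params, observations)
# -> action probabilities of shape (batch, 2), values of shape (batch, 1)
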
def lambda_fun3(x, y, z, w, v):
  """Combines a combinator tree of (x, y, z) with w and v via fan-out/sum."""
  # `tree_spec` and `_build_combinator_tree` come from the enclosing scope.
  input_tree = _build_combinator_tree(tree_spec, (x, y, z))
  return stax.serial(input_tree, stax.FanOut(3),
                     stax.parallel(w, v, stax.Identity),
                     stax.FanInSum)

def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  # stax.Dropout takes the probability of *keeping* an activation.
  keep_rate = 1.0 - dropout
  # Multi-headed Attention and Feed-forward layers
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

  feed_forward = stax.serial(
      stax.Dense(feedforward_depth, W_init=stax.xavier_uniform()),
      stax.Relu,
      stax.Dropout(keep_rate, mode=mode),
      stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  )

  # Single decoder layer
  decoder_layer = stax.serial(
      # target attends to self
      stax.residual(stax.LayerNorm(),
                    stax.FanOut(4),
                    stax.parallel(stax.Identity,  # query
                                  stax.Identity,  # key
                                  stax.Identity,  # value
                                  stax.CausalMask(axis=-2)),  # attention mask
                    multi_attention,
                    stax.Dropout(keep_rate, mode=mode)),
      # feed-forward
      stax.residual(stax.LayerNorm(),
                    feed_forward,
                    stax.Dropout(keep_rate, mode=mode))
  )

  return stax.serial(
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(keep_rate, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax
  )

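# Usage sketch, assuming init/apply follow the (output_shape, params)
# convention visible in policy_and_value_net above; the vocabulary size and
# token shapes are illustrative:
from jax import random

init_fun, apply_fun = TransformerLM(vocab_size=32000, mode='train')
_, params = init_fun(random.PRNGKey(0), (16, 256))  # (batch, seq_len)
# log_probs = apply_fun(params, token_ids, rng=random.PRNGKey(1))
# -> per-position log-probabilities over the 32000-symbol vocabulary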