def Decoder(memory, target, target_mask, memory_mask):
  """Assemble the Transformer decoder stack.

  Relies on names from the enclosing scope: `multi_attention`, `dropout`,
  `mode`, `feature_depth`, `feedforward_depth`, `num_layers`,
  `target_embedding_layer` and `ResidualFeedForward`.

  Args:
    memory: staxlayer variable: encoded source sequences
    target: staxlayer variable: raw target sequences
    target_mask: staxlayer variable: self-attention mask
    memory_mask: staxlayer variable: memory attention mask

  Returns:
    Staxlayer variable: `target` fed through `target_embedding_layer`, the
    decoder layer wrapped in `stax.repeat(..., num_layers)`, and a final
    LayerNorm.
  """
  # Self-attention sub-layer: query, key and value are all the target stream.
  self_attention = stax.residual(
      stax.LayerNorm(),
      stax.FanOut(4),
      stax.parallel(stax.Identity, stax.Identity, stax.Identity, target_mask),
      multi_attention,
      stax.Dropout(dropout, mode=mode))

  # Memory-attention sub-layer: query is the target stream, key and value
  # are the encoded source.
  memory_attention = stax.residual(
      stax.LayerNorm(),
      stax.FanOut(4),
      stax.parallel(stax.Identity, memory, memory, memory_mask),
      multi_attention,
      stax.Dropout(dropout, mode=mode))

  # Position-wise feed-forward sub-layer.
  feed_forward = ResidualFeedForward(
      feature_depth, feedforward_depth, dropout, mode=mode)

  decoder_layer = stax.serial(self_attention, memory_attention, feed_forward)
  stacked_layers = stax.repeat(decoder_layer, num_layers)
  return stax.serial(
      target, target_embedding_layer, stacked_layers, stax.LayerNorm())
def IdentityBlock(kernel_size, filters):
  """ResNet block whose main branch is summed with an identity shortcut.

  Args:
    kernel_size: side length of the square kernel of the middle convolution.
    filters: pair `(filters1, filters2)` with the channel counts of the
      first and second convolutions of the main branch.

  Returns:
    A stax layer: fan-out into (main branch, identity), sum, then Relu.
  """
  first_channels, second_channels = filters
  middle_kernel = (kernel_size, kernel_size)

  def _main_branch(input_shape):
    # The last 1x1 convolution restores the input channel count (read from
    # input_shape[3]) so the result can be summed with the shortcut.
    branch_layers = [
        stax.Conv(first_channels, (1, 1)),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(second_channels, middle_kernel, padding='SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(input_shape[3], (1, 1)),
        stax.BatchNorm(),
    ]
    return stax.serial(*branch_layers)

  main_branch = stax.shape_dependent(_main_branch)
  split = stax.FanOut(2)
  branches = stax.parallel(main_branch, stax.Identity)
  return stax.serial(split, branches, stax.FanInSum, stax.Relu)
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  # Stem convolution.
  layers = [stax.Conv(hidden_size, (3, 3), padding='SAME')]

  # Three groups; the second and third widen the channels and use (2, 2)
  # strides, while the first keeps WideResnetGroup's default strides.
  layers.append(WideResnetGroup(num_blocks, hidden_size))
  for width_factor in (2, 4):
    layers.append(
        WideResnetGroup(num_blocks, hidden_size * width_factor, (2, 2)))

  # Classification head.
  layers.extend([
      stax.BatchNorm(),
      stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax,
  ])
  return stax.serial(*layers)
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.
    mode: whether we are training or evaluating or doing inference.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  del mode  # Accepted as a parameter but not used when building the layers.

  # One row per stage:
  # (width multiplier, ConvBlock strides, number of IdentityBlocks).
  stage_table = (
      (1, (1, 1), 2),
      (2, (2, 2), 3),
      (4, (2, 2), 5),
      (8, (2, 2), 2),
  )

  # Stem: strided 7x7 convolution followed by max pooling.
  layers = [
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(),
      stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
  ]

  # Each stage is one ConvBlock followed by its IdentityBlocks.
  for multiplier, strides, identity_count in stage_table:
    width = multiplier * hidden_size
    layers.append(ConvBlock(3, [width, width, 4 * width], strides))
    for _ in range(identity_count):
      layers.append(IdentityBlock(3, [width, width]))

  # Classification head.
  layers += [
      stax.AvgPool((7, 7)),
      stax.Flatten(),
      stax.Dense(num_output_classes),
      stax.LogSoftmax,
  ]
  return stax.serial(*layers)
def TransformerLM(vocab_size,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_layers=6,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048,
                  mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    init and apply.
  """
  # Input side: shift right, embed, dropout, positional encoding.
  input_layers = [
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(dropout, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
  ]

  # A single DecoderLayer handed to stax.repeat together with num_layers.
  decoder_layer = DecoderLayer(
      feature_depth, feedforward_depth, num_heads, dropout, mode)
  decoder_stack = stax.repeat(decoder_layer, num_layers)

  # Output side: normalize, project to the vocabulary, log-softmax.
  output_layers = [
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax,
  ]
  return stax.serial(*input_layers, decoder_stack, *output_layers)
def lambda_fun3(x, y, z, w, v):
  """Fan a combinator tree of (x, y, z) out to w, v and Identity, then sum.

  `tree_spec` and `_build_combinator_tree` come from the surrounding scope.
  """
  combined_inputs = _build_combinator_tree(tree_spec, (x, y, z))
  three_copies = stax.FanOut(3)
  branches = stax.parallel(w, v, stax.Identity)
  return stax.serial(combined_inputs, three_copies, branches, stax.FanInSum)
def WideResnetGroup(n, channels, strides=(1, 1)):
  """Chain WideResnetBlocks into one group.

  Args:
    n: number of blocks in the group; the first block is always created,
      and `n - 1` further blocks follow it.
    channels: channel count passed to every block.
    strides: strides for the first block only; the rest use (1, 1).

  Returns:
    A stax.serial layer of the blocks in order.
  """
  # Only the leading block takes the caller's strides and is flagged with
  # channel_mismatch=True.
  leading_block = WideResnetBlock(channels, strides, channel_mismatch=True)
  following_blocks = [
      WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)
  ]
  return stax.serial(leading_block, *following_blocks)
def generator(encoded_target):
  """Project `encoded_target` to target-vocabulary log-probabilities.

  `target_vocab_size` is taken from the surrounding scope.
  """
  vocab_projection = stax.Dense(
      target_vocab_size, W_init=stax.xavier_uniform())
  return stax.serial(encoded_target, vocab_projection, stax.LogSoftmax)
def TransformerEncoder(mode='train',  # pylint: disable=invalid-name
                       num_layers=6,
                       feature_depth=512,
                       feedforward_depth=2048,
                       num_heads=8,
                       dropout=0.9):
  """Transformer Encoder Stack.

  Args:
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate - Stax follows TF's KEEP probability
      convention

  Returns:
    A staxlayer for implementing a raw Transformer encoder stack. No
    embedding or positional signals are added by this layer.
  """
  # Attention and feed-forward pieces are built once here and referenced by
  # the encoder layer below.
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  widen = stax.Dense(feedforward_depth, W_init=stax.xavier_uniform())
  hidden_dropout = stax.Dropout(dropout, mode=mode)
  narrow = stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  feed_forward = stax.serial(widen, stax.Relu, hidden_dropout, narrow)

  @stax.Lambda
  def encoder(embedded_source, source_mask):
    """Transformer encoder stack.

    Args:
      embedded_source: staxlayer variable: embedded source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    # Self-attention sub-layer: query, key and value are the input itself,
    # with source_mask as the attention mask.
    attention_norm = stax.LayerNorm(feature_depth)
    attention_inputs = stax.multiplex(
        stax.Identity, stax.Identity, stax.Identity, source_mask)
    attention_block = stax.residual(
        attention_norm,
        attention_inputs,
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    # Feed-forward sub-layer.
    feed_forward_block = stax.residual(
        stax.LayerNorm(feature_depth),
        feed_forward,
        stax.Dropout(dropout, mode=mode))

    encoder_layer = stax.serial(attention_block, feed_forward_block)
    stacked_layers = stax.repeat(encoder_layer, num_layers)
    return stax.serial(
        embedded_source, stacked_layers, stax.LayerNorm(feature_depth))

  return encoder
def Transformer(source_vocab_size,  # pylint: disable=invalid-name
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.9,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate - Stax follows TF's KEEP probability
      convention
    shared_embedding: bool: specify whether source/target embeddings are
      tied.
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    When `return_evals` is false, the `(init, apply)` pair of the whole
    model. Otherwise a namedtuple with fields `init`, `apply` and `evals`,
    where `evals(params)` returns a namedtuple of apply functions for the
    encoder, generator and decoder.
  """
  # Token embedding is followed by positional encoding, then dropout.
  position_signal = stax.PositionalEncoding(feature_depth, max_len=max_len)
  position_dropout = stax.Dropout(dropout, mode=mode)
  inject_position = stax.serial(position_signal, position_dropout)

  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # One weight-shared embedding serves both source and target.
    tied_embedding = stax.Share(
        stax.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = stax.serial(tied_embedding, inject_position)
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = stax.Embedding(feature_depth, source_vocab_size)
    target_embedding = stax.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = stax.serial(source_embedding, inject_position)
    target_embedding_layer = stax.serial(target_embedding, inject_position)

  # Attention and feed-forward pieces referenced by the encoder and decoder
  # layers below.
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  widen = stax.Dense(feedforward_depth, W_init=stax.xavier_uniform())
  hidden_dropout = stax.Dropout(dropout, mode=mode)
  narrow = stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  feed_forward = stax.serial(widen, stax.Relu, hidden_dropout, narrow)

  @stax.Lambda
  def encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    # Self-attention: query, key and value are all the source stream.
    attention_norm = stax.LayerNorm(feature_depth)
    attention_inputs = stax.multiplex(
        stax.Identity, stax.Identity, stax.Identity, source_mask)
    attention_block = stax.residual(
        attention_norm,
        attention_inputs,
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    feed_forward_block = stax.residual(
        stax.LayerNorm(feature_depth),
        feed_forward,
        stax.Dropout(dropout, mode=mode))

    encoder_layer = stax.serial(attention_block, feed_forward_block)
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  @stax.Lambda
  def decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable: the embedded target passed through the stacked
      decoder layers and a final LayerNorm.
    """
    # Self-attention: query, key and value are all the target stream.
    self_norm = stax.LayerNorm(feature_depth)
    self_inputs = stax.multiplex(
        stax.Identity, stax.Identity, stax.Identity, target_mask)
    self_attention_block = stax.residual(
        self_norm,
        self_inputs,
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    # Memory attention: query is the target stream, key and value are the
    # encoded source.
    memory_norm = stax.LayerNorm(feature_depth)
    memory_inputs = stax.multiplex(
        stax.Identity, memory, memory, memory_mask)
    memory_attention_block = stax.residual(
        memory_norm,
        memory_inputs,
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    feed_forward_block = stax.residual(
        stax.LayerNorm(feature_depth),
        feed_forward,
        stax.Dropout(dropout, mode=mode))

    decoder_layer = stax.serial(
        self_attention_block, memory_attention_block, feed_forward_block)
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(feature_depth),
    )

  @stax.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):
    """Encode the source, then decode the target against it."""
    encoded_source = encoder(source, source_mask)
    return decoder(encoded_source, target, target_mask, memory_mask)

  # The generator is bound separately so it can be used later for inference.
  @stax.Lambda
  def generator(encoded_target):
    """Project the decoder output to target-vocabulary log-probabilities."""
    vocab_projection = stax.Dense(
        target_vocab_size, W_init=stax.xavier_uniform())
    return stax.serial(encoded_target, vocab_projection, stax.LogSoftmax)

  # Init and apply pair for the entire model.
  top_init, top_apply = generator(transformer)

  # Default: behave like a normal Stax constructor.
  if not return_evals:
    return (top_init, top_apply)

  def make_namedtuple(**kwargs):
    """Pack keyword arguments into a one-off 'Model' namedtuple."""
    model_type = collections.namedtuple('Model', kwargs.keys())
    return model_type(**kwargs)

  def get_evals(params):
    """Bind trained `params` and return per-component apply functions."""
    # _Concrete_ trained parameters must be fed through the network once;
    # otherwise the bound parameters point to abstract tracer values. The
    # input values themselves do not matter.
    placeholder = np.ones((1), dtype=np.int32)
    fake_inputs = 5 * (placeholder,)
    fake_key = random.PRNGKey(1)
    top_apply(params, fake_inputs, rng=fake_key)
    # Eval functions can now be built from the bound pieces of the model.
    return make_namedtuple(
        encoder=stax.make_apply_fun(encoder),
        generator=stax.make_apply_fun(generator),
        decoder=stax.make_apply_fun(decoder),
    )

  # Everything needed to train and evaluate the Transformer.
  return make_namedtuple(
      init=top_init,
      apply=top_apply,
      evals=get_evals,
  )
def TransformerLM(vocab_size,  # pylint: disable=invalid-name
                  mode='train',
                  num_layers=6,
                  feature_depth=512,
                  feedforward_depth=2048,
                  num_heads=8,
                  dropout=0.1,
                  max_len=2048):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding

  Returns:
    init and apply.
  """
  # The stax layers below are given 1 - dropout rather than the rate itself.
  keep_rate = 1.0 - dropout

  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=keep_rate, mode=mode)

  widen = stax.Dense(feedforward_depth, W_init=stax.xavier_uniform())
  hidden_dropout = stax.Dropout(keep_rate, mode=mode)
  narrow = stax.Dense(feature_depth, W_init=stax.xavier_uniform())
  feed_forward = stax.serial(widen, stax.Relu, hidden_dropout, narrow)

  # Self-attention sub-layer: query, key and value are the input itself,
  # and the fourth branch is a CausalMask used as the attention mask.
  attention_norm = stax.LayerNorm()
  four_copies = stax.FanOut(4)
  attention_inputs = stax.parallel(
      stax.Identity, stax.Identity, stax.Identity, stax.CausalMask(axis=-2))
  attention_block = stax.residual(
      attention_norm,
      four_copies,
      attention_inputs,
      multi_attention,
      stax.Dropout(keep_rate, mode=mode))

  # Feed-forward sub-layer.
  feed_forward_block = stax.residual(
      stax.LayerNorm(),
      feed_forward,
      stax.Dropout(keep_rate, mode=mode))

  decoder_layer = stax.serial(attention_block, feed_forward_block)

  model_layers = [
      stax.ShiftRight(),
      stax.Embedding(feature_depth, vocab_size),
      stax.Dropout(keep_rate, mode=mode),
      stax.PositionalEncoding(feature_depth, max_len=max_len),
      stax.repeat(decoder_layer, num_layers),
      stax.LayerNorm(),
      stax.Dense(vocab_size, W_init=stax.xavier_uniform()),
      stax.LogSoftmax,
  ]
  return stax.serial(*model_layers)
def Transformer(source_vocab_size,
                target_vocab_size,
                mode='train',
                num_layers=6,
                feature_depth=512,
                feedforward_depth=2048,
                num_heads=8,
                dropout=0.1,
                shared_embedding=True,
                max_len=200,
                return_evals=False):
  """Transformer model.

  Args:
    source_vocab_size: int: source vocab size
    target_vocab_size: int: target vocab size
    mode: str: 'train' or 'eval'
    num_layers: int: number of encoder/decoder layers
    feature_depth: int: depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    shared_embedding: bool: specify whether source/target embeddings are
      tied.
    max_len: int: maximum symbol length for positional encoding
    return_evals: bool: whether to generate decode-time evaluation functions

  Returns:
    The `(init, apply)` pair of the whole model.

  Raises:
    ValueError: if `return_evals` is true; inference is not supported here.
  """
  # Token embedding is followed by dropout, then positional encoding.
  embedding_dropout = stax.Dropout(dropout, mode=mode)
  position_signal = stax.PositionalEncoding(feature_depth, max_len=max_len)
  inject_position = stax.serial(embedding_dropout, position_signal)

  if shared_embedding:
    assert source_vocab_size == target_vocab_size
    # One weight-shared embedding serves both source and target.
    tied_embedding = stax.Share(
        stax.Embedding(feature_depth, source_vocab_size))
    source_embedding_layer = stax.serial(tied_embedding, inject_position)
    target_embedding_layer = source_embedding_layer
  else:
    source_embedding = stax.Embedding(feature_depth, source_vocab_size)
    target_embedding = stax.Embedding(feature_depth, target_vocab_size)
    source_embedding_layer = stax.serial(source_embedding, inject_position)
    target_embedding_layer = stax.serial(target_embedding, inject_position)

  # Attention layer referenced by the encoder and decoder layers below.
  multi_attention = stax.MultiHeadedAttention(
      feature_depth, num_heads=num_heads, dropout=dropout, mode=mode)

  @stax.Lambda
  def Encoder(source, source_mask):
    """Transformer encoder stack.

    Args:
      source: staxlayer variable: raw source sequences
      source_mask: staxlayer variable: self-attention mask

    Returns:
      Staxlayer variable that outputs encoded source.
    """
    # Self-attention: query, key and value are all the source stream.
    attention_block = stax.residual(
        stax.LayerNorm(),
        stax.FanOut(4),
        stax.parallel(
            stax.Identity, stax.Identity, stax.Identity, source_mask),
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    feed_forward_block = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    encoder_layer = stax.serial(attention_block, feed_forward_block)
    return stax.serial(
        source,
        source_embedding_layer,
        stax.repeat(encoder_layer, num_layers),
        stax.LayerNorm(),
    )

  @stax.Lambda
  def Decoder(memory, target, target_mask, memory_mask):
    """Transformer decoder stack.

    Args:
      memory: staxlayer variable: encoded source sequences
      target: staxlayer variable: raw target sequences
      target_mask: staxlayer variable: self-attention mask
      memory_mask: staxlayer variable: memory attention mask

    Returns:
      Staxlayer variable: the embedded target passed through the stacked
      decoder layers and a final LayerNorm.
    """
    # Self-attention: query, key and value are all the target stream.
    self_attention_block = stax.residual(
        stax.LayerNorm(),
        stax.FanOut(4),
        stax.parallel(
            stax.Identity, stax.Identity, stax.Identity, target_mask),
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    # Memory attention: query is the target stream, key and value are the
    # encoded source.
    memory_attention_block = stax.residual(
        stax.LayerNorm(),
        stax.FanOut(4),
        stax.parallel(stax.Identity, memory, memory, memory_mask),
        multi_attention,
        stax.Dropout(dropout, mode=mode))

    feed_forward_block = ResidualFeedForward(
        feature_depth, feedforward_depth, dropout, mode=mode)

    decoder_layer = stax.serial(
        self_attention_block, memory_attention_block, feed_forward_block)
    return stax.serial(
        target,
        target_embedding_layer,
        stax.repeat(decoder_layer, num_layers),
        stax.LayerNorm(),
    )

  @stax.Lambda
  def transformer(source, target, source_mask, target_mask, memory_mask):  # pylint: disable=invalid-name
    """Encode the source, then decode the target against it."""
    encoded_source = Encoder(source, source_mask)
    return Decoder(encoded_source, target, target_mask, memory_mask)

  # The generator is bound separately so it can be used later for inference.
  @stax.Lambda
  def Generator(encoded_target):
    """Project the decoder output to target-vocabulary log-probabilities."""
    vocab_projection = stax.Dense(
        target_vocab_size, W_init=stax.xavier_uniform())
    return stax.serial(encoded_target, vocab_projection, stax.LogSoftmax)

  # Init and apply pair for the entire model.
  top_init, top_apply = Generator(transformer)

  # Evaluation functions are not available for this variant.
  if return_evals:
    raise ValueError('inference in this model is still a work in progress')

  # Behave like a normal Stax constructor and emit an (init, apply) pair.
  return (top_init, top_apply)
def lambda_fun2(x, y, z, w, v):
  """Multiplex a combinator tree of (x, y, z) into w, Identity and v, then sum.

  `tree_spec` and `_build_combinator_tree` come from the surrounding scope.
  """
  combined_inputs = _build_combinator_tree(tree_spec, (x, y, z))
  branches = stax.multiplex(w, stax.Identity, v)
  return stax.serial(combined_inputs, branches, stax.FanInSum)