def GRULM(vocab_size=256, d_model=512, n_layers=2, mode='train'):
    """Builds a GRU language model.

    Args:
        vocab_size (int, optional): Size of the vocabulary. Defaults to 256.
        d_model (int, optional): Depth of embedding (n_units in the GRU cell).
            Defaults to 512.
        n_layers (int, optional): Number of GRU layers. Defaults to 2.
        mode (str, optional): 'train', 'eval' or 'predict'; predict mode is
            for fast inference. Defaults to 'train'.

    Returns:
        trax.layers.combinators.Serial: A GRU language model as a layer that
        maps from a tensor of tokens to activations over a vocab set.
    """
    # Recurrent core: n_layers GRU layers, each d_model units wide.
    gru_stack = [tl.GRU(n_units=d_model) for _ in range(n_layers)]

    model = tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(vocab_size=vocab_size, d_feature=d_model),
        gru_stack,
        tl.Dense(n_units=vocab_size),
        tl.LogSoftmax(),
    )
    return model
def GRULM(vocab_size=256, d_model=512, n_layers=1, mode='train'):
    """Returns a GRU language model.

    The input to the model is a tensor of tokens (ints).

    Args:
        vocab_size: int: vocab size
        d_model: int: depth of embedding (n_units in the RNN cell)
        n_layers: int: number of RNN layers
        mode: str: 'train', 'eval' or 'predict', predict mode is for fast
            inference

    Returns:
        An RNN language model as a layer that maps from a tensor of tokens to
        activations over a vocab set.
    """
    return tl.Serial(
        tl.ShiftRight(mode=mode),
        # Keyword arguments on purpose: the positional order of
        # (vocab_size, d_feature) is not the same in every caller of
        # tl.Embedding, so naming them keeps the table shape unambiguous.
        tl.Embedding(vocab_size=vocab_size, d_feature=d_model),
        [tl.GRU(d_model) for _ in range(n_layers)],
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Assembles a decoder-only Transformer language model.

    Args:
        vocab_size: Size of the embedding table and of the final Dense layer.
        d_model: Embedding depth; also passed to every DecoderBlock.
        d_ff: Passed to every DecoderBlock (feed-forward depth).
        n_layers: Number of DecoderBlock instances to stack.
        n_heads: Passed to every DecoderBlock (attention heads).
        dropout: Dropout rate for the embedding dropout and DecoderBlocks.
        max_len: `max_len` handed to `tl.PositionalEncoding`.
        mode: Forwarded to ShiftRight, Dropout, PositionalEncoding and the
            DecoderBlocks.
        ff_activation: Passed to every DecoderBlock.

    Returns:
        A `tl.Serial` layer: ShiftRight, embedding + dropout + positional
        encoding, the decoder blocks, LayerNorm, Dense(vocab_size) and
        LogSoftmax.
    """
    # Token embedding followed by dropout and positional information.
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation))

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        embedding_stack,
        blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def GRULM(vocab_size=256, d_model=512, n_layers=1, mode='train'):
    """Returns a GRU (gated recurrent unit) language model.

    This model performs autoregressive language modeling:

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length). The
        tensor elements are integers in `range(vocab_size)`, and `0` values
        mark padding positions.

      - output: rank 3 tensor with one score per possible token ID for each
        sequence position; shape is (batch_size, sequence_length,
        `vocab_size`). NOTE: the stack built here ends with a `Dense` layer
        and does not apply `LogSoftmax`, so these are unnormalized scores,
        not log-probabilities.

    Args:
        vocab_size: Input vocabulary size -- each element of the input tensor
            should be an integer in `range(vocab_size)`. These integers
            typically represent token IDs from a vocabulary-based tokenizer.
        d_model: Embedding depth throughout the model.
        n_layers: Number of GRU layers.
        mode: If `'predict'`, use fast inference (and omit the right shift).

    Returns:
        A GRU language model as a layer that maps from a tensor of tokens to
        activations over a vocab set.
    """
    shift = tl.ShiftRight(mode=mode)
    embed = tl.Embedding(vocab_size, d_model)
    recurrent = [tl.GRU(d_model) for _ in range(n_layers)]
    project = tl.Dense(vocab_size)
    return tl.Serial(shift, embed, recurrent, project)
def PositionLookupTransformerLM(vocab_size=128,
                                d_model=256,
                                d_ff=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
    """Transformer language model (only uses the decoder part of Transformer).

    Args:
        vocab_size: int: vocab size
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_layers: int: number of layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: maximal length; only the first `max_len` rows of the
            module-level `_POSITIONS` table are used.
        mode: str: 'train' or 'eval'

    Returns:
        the layer.
    """
    positions = _POSITIONS[:max_len, :]
    return tl.Serial(
        tl.ShiftRight(),
        # Keyword arguments on purpose: the positional order of
        # (vocab_size, d_feature) is not the same in every caller of
        # tl.Embedding, so naming them keeps the table shape unambiguous.
        tl.Embedding(vocab_size=vocab_size, d_feature=d_model),
        tl.Dropout(rate=dropout, mode=mode),
        NewPositionalEncoding(positions=positions),
        [
            DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode)
            for _ in range(n_layers)
        ],
        PreservePosition(tl.LayerNorm()),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def test_shift_right(self):
    """ShiftRight moves entries one step along axis 1, zero-filling the front."""
    x = np.array([[[9, 9, 9], [8, 8, 8], [7, 7, 7], [6, 6, 6]],
                  [[99, 98, 97], [96, 95, 94], [93, 92, 91], [90, 89, 88]]])
    # Each batch item gains a leading all-zero step and loses its last step.
    expected = [[[0, 0, 0], [9, 9, 9], [8, 8, 8], [7, 7, 7]],
                [[0, 0, 0], [99, 98, 97], [96, 95, 94], [93, 92, 91]]]

    shifter = tl.ShiftRight()
    y = shifter(x)

    self.assertEqual(x.shape, y.shape)
    self.assertEqual(tl.to_list(y), expected)
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
        vocab_size (int): vocab size.
        d_model (int): depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of decoder layers.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast
            inference.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        trax.layers.combinators.Serial: A Transformer language model as a
        layer that maps from a tensor of tokens to activations over a vocab
        set.
    """
    # Teacher forcing: the model sees the previous token at each position.
    shift = tl.ShiftRight(mode=mode)

    # Token embedding plus positional information.
    embed = PositionalEncoder(vocab_size, d_model, dropout, max_len, mode)

    # n_layers identical decoder blocks.
    blocks = []
    for _ in range(n_layers):
        blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation))

    # Final normalization, projection to the vocabulary, and log-softmax.
    head = [tl.LayerNorm(), tl.Dense(vocab_size), tl.LogSoftmax()]

    return tl.Serial(shift, embed, blocks, *head)
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  dropout_shared_axes=None,
                  max_len=2048,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens. (This model uses only the
    decoder part of the overall Transformer.)

    Args:
        vocab_size: int: vocab size
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_layers: int: number of decoder blocks to stack
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        dropout_shared_axes: axes on which to share dropout mask
        max_len: int: maximum symbol length for positional encoding
        mode: str: 'train', 'eval' or 'predict', predict mode is for fast
            inference
        ff_activation: the non-linearity in feed-forward layer

    Returns:
        A Transformer language model as a layer that maps from a tensor of
        tokens to activations over a vocab set.
    """
    embedder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    # Stack contents after each step are noted on the right.
    return tl.Serial(              # tokens (or chunked tuple of tokens)
        tl.ShiftRight(mode=mode),  # toks
        embedder,                  # vecs
        blocks,                    # vecs
        tl.LayerNorm(),            # vecs
        tl.Dense(vocab_size),      # vecs
        tl.LogSoftmax(),           # vecs
    )
def test_shift_right_float(self):
    """ShiftRight keeps float32 dtype and zero-fills the first step of axis 1."""
    raw = np.array([[[9, 9, 9], [8, 8, 8], [7, 7, 7], [6, 6, 6]],
                    [[99, 98, 97], [96, 95, 94], [93, 92, 91], [90, 89, 88]]])
    x = raw.astype(np.float32)
    x /= 2.0  # In-place division keeps the array float32.
    self.assertEqual(x.dtype, np.float32)

    expected = [[[0.0, 0.0, 0.0], [4.5, 4.5, 4.5], [4.0, 4.0, 4.0],
                 [3.5, 3.5, 3.5]],
                [[0.0, 0.0, 0.0], [49.5, 49.0, 48.5], [48.0, 47.5, 47.0],
                 [46.5, 46.0, 45.5]]]

    shifter = tl.ShiftRight()
    y = shifter(x)

    self.assertEqual(y.dtype, np.float32)
    self.assertEqual(tl.to_list(y), expected)
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
    """Returns an RNN language model.

    The input to the model is a tensor of tokens (ints).

    Args:
        vocab_size: int: vocab size
        d_model: int: depth of embedding (n_units in the RNN cell)
        n_layers: int: number of RNN layers; only 2 is supported
        rnn_cell: the RNN cell
        rnn_cell_d_state_multiplier: how many times is RNN cell state larger
        dropout: float: dropout rate (how much to drop out)
        mode: str: 'train', 'eval' or 'predict', predict mode is for fast
            inference

    Returns:
        An RNN language model as a layer that maps from a tensor of tokens to
        activations over a vocab set.

    Raises:
        ValueError: If `n_layers` is not 2.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if n_layers != 2:
        raise ValueError(
            f'Number of layers must be set to 2; instead got {n_layers}.')

    def MultiRNNCell():
        """Multi-layer RNN cell."""
        return tl.Serial(
            # Split the combined state into one piece per layer.
            tl.Parallel([], tl.Split(n_items=n_layers)),
            tl.SerialWithSideOutputs(
                [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
            # Re-join the per-layer states into one combined state.
            tl.Parallel([], tl.Concatenate(n_items=n_layers))
        )

    zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
        depth_multiplier=n_layers * rnn_cell_d_state_multiplier
    )

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        # Keyword arguments on purpose: the positional order of
        # (vocab_size, d_feature) is not the same in every caller of
        # tl.Embedding, so naming them keeps the table shape unambiguous.
        tl.Embedding(vocab_size=vocab_size, d_feature=d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Branch([], zero_state),
        tl.Scan(MultiRNNCell(), axis=1),
        tl.Select([0], n_in=2),  # Drop RNN state.
        tl.Dense(vocab_size),
        tl.LogSoftmax()
    )
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Returns a Transformer language model.

    The input to the model is a tensor of tokens.

    Args:
        vocab_size (int): vocab size.
        d_model (int): depth of embedding.
        d_ff (int): depth of feed-forward layer.
        n_layers (int): number of `Encoder` blocks to stack.
        n_heads (int): number of attention heads.
        dropout (float): dropout rate (how much to drop out).
        max_len (int): maximum symbol length for positional encoding.
        mode (str): 'train', 'eval' or 'predict', predict mode is for fast
            inference.
        ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
        trax.layers.combinators.Serial: A Transformer language model as a
        layer that maps from a tensor of tokens to activations over a vocab
        set.
    """
    # Token embedding, dropout, then positional information.
    embed_tokens = tl.Embedding(vocab_size, d_model)
    embed_dropout = tl.Dropout(rate=dropout, mode=mode)
    add_positions = tl.PositionalEncoding(max_len=max_len, mode=mode)

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            Encoder(d_model, d_ff, n_heads, dropout, mode, ff_activation))

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        [embed_tokens, embed_dropout, add_positions],
        blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def TransformerLM(vocab_size=33300,
                  embeddingDepth=512,
                  depth=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  maxLength=4096,
                  mode='train',
                  ffActivation=tl.Relu):
    """Assembles a decoder-only Transformer language model.

    Args:
        vocab_size: Vocabulary size; passed to PositionalEncoder and used as
            the width of the final Dense layer.
        embeddingDepth: Embedding depth handed to PositionalEncoder and to
            every DecoderBlock.
        depth: Second argument of every DecoderBlock (presumably the
            feed-forward depth -- confirm against DecoderBlock).
        n_layers: Number of DecoderBlock instances to stack.
        n_heads: Passed to every DecoderBlock (attention heads).
        dropout: Dropout rate for PositionalEncoder and the DecoderBlocks.
        maxLength: Passed to PositionalEncoder as its length limit.
        mode: Forwarded to ShiftRight, PositionalEncoder and DecoderBlocks.
        ffActivation: Passed to every DecoderBlock.

    Returns:
        A `tl.Serial` layer: ShiftRight, PositionalEncoder, the decoder
        blocks, LayerNorm, Dense(vocab_size) and LogSoftmax.
    """
    blocks = []
    for _ in range(n_layers):
        blocks.append(
            DecoderBlock(embeddingDepth, depth, n_heads, dropout, mode,
                         ffActivation))

    shift = tl.ShiftRight(mode=mode)
    embed = PositionalEncoder(vocab_size, embeddingDepth, dropout, maxLength,
                              mode)
    return tl.Serial(
        shift,
        embed,
        blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
    """Returns an RNN language model.

    This model performs autoregressive language modeling:

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length). The
        tensor elements are integers in `range(vocab_size)`, and `0` values
        mark padding positions.

      - output: rank 3 tensor with one score per possible token ID for each
        sequence position; shape is (batch_size, sequence_length,
        `vocab_size`). NOTE: the stack built here ends with a `Dense` layer
        and does not apply `LogSoftmax`, so these are unnormalized scores,
        not log-probabilities.

    Args:
        vocab_size: Input vocabulary size -- each element of the input tensor
            should be an integer in `range(vocab_size)`. These integers
            typically represent token IDs from a vocabulary-based tokenizer.
        d_model: Embedding depth throughout the model.
        n_layers: Number of RNN layers; must be 2.
        rnn_cell: Type of RNN cell; must be a subclass of `Layer`.
        rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell
            state.
        dropout: Stochastic rate (probability) for dropping an activation
            value when applying dropout.
        mode: If `'predict'`, use fast inference; if `'train'` apply dropout.

    Returns:
        An RNN language model as a layer that maps from a tensor of tokens to
        activations over a vocab set.

    Raises:
        ValueError: If `n_layers` is not 2.
    """
    # TODO(jonni): Remove n_layers arg, if it can't vary?
    if n_layers != 2:
        raise ValueError(
            f'Number of layers must be set to 2; instead got {n_layers}.')

    # Initial (all-zero) recurrent state, sized for every layer's cell state.
    state_depth = n_layers * rnn_cell_d_state_multiplier
    zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
        depth_multiplier=state_depth
    )

    # Multi-layer RNN cell: split the combined state per layer, run the cells
    # in sequence, then concatenate the per-layer states back together.
    cells = [rnn_cell(n_units=d_model) for _ in range(n_layers)]
    multi_cell = tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(cells),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Branch([], zero_state),
        tl.Scan(multi_cell, axis=1),
        tl.Select([0], n_in=2),  # Drop RNN state.
        tl.Dense(vocab_size),
    )
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
    """Returns an LSTM sequence-to-sequence model with attention.

    This model is an encoder-decoder that performs tokenized string-to-string
    ("source"-to-"target") transduction:

      - inputs (2):

          - source: rank 2 tensor representing a batch of text strings via
            token IDs plus padding markers; shape is
            (batch_size, sequence_length). The tensor elements are integers in
            `range(input_vocab_size)`, and `0` values mark padding positions.

          - target: rank 2 tensor representing a batch of text strings via
            token IDs plus padding markers; shape is
            (batch_size, sequence_length). The tensor elements are integers in
            `range(target_vocab_size)`, and `0` values mark padding positions.

      - output: rank 3 tensor representing a batch of log-probability
        distributions for each sequence position over possible token IDs;
        shape is (batch_size, sequence_length, `target_vocab_size`).

    An example use would be to translate (tokenized) sentences from English to
    German.

    The model works as follows:

    * Input encoder runs on the input tokens and creates activations that
      are used as both keys and values in attention.
    * Pre-attention decoder runs on the targets and creates
      activations that are used as queries in attention.
    * Attention runs on the queries, keys and values masking out input
      padding.
    * Decoder LSTMs run on the result, followed by a Dense projection to the
      target vocabulary and a LogSoftmax.

    Args:
        input_vocab_size: Input vocabulary size -- each element of the input
            tensor should be an integer in `range(input_vocab_size)`. These
            integers typically represent token IDs from a vocabulary-based
            tokenizer.
        target_vocab_size: Target vocabulary size.
        d_model: Final dimension of tensors at most points in the model,
            including the initial embedding output.
        n_encoder_layers: Number of LSTM layers in the encoder.
        n_decoder_layers: Number of LSTM layers in the decoder after
            attention.
        n_attention_heads: Number of attention heads.
        attention_dropout: Stochastic rate (probability) for dropping an
            activation value when applying dropout within an attention block.
        mode: If `'predict'`, use fast inference. If `'train'`, each attention
            block will include dropout; else, it will pass all values through
            unaltered.

    Returns:
        An LSTM sequence-to-sequence model as a layer that maps from a
        source-target tokenized text pair to activations over a vocab set.
    """
    encoder_lstms = [tl.LSTM(d_model) for _ in range(n_encoder_layers)]
    input_encoder = tl.Serial(
        tl.Embedding(input_vocab_size, d_model),
        encoder_lstms,
    )

    pre_attention_decoder = tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(target_vocab_size, d_model),
        tl.LSTM(d_model),
    )

    def _attention_inputs(encoder_activations, decoder_activations,
                          input_tokens):
        """Computes (queries, keys, values, mask) for the attention layer."""
        # True where the source token is real, False where it is padding (0).
        not_padding = (input_tokens != 0)
        batch_size = not_padding.shape[0]
        encoder_len = not_padding.shape[1]
        decoder_len = decoder_activations.shape[1]
        # Insert axes for attention heads and decoder length ...
        mask = jnp.reshape(not_padding, (batch_size, 1, 1, encoder_len))
        # ... then broadcast to [batch, 1 for heads, decoder-len, encoder-len].
        mask = mask + jnp.zeros((1, 1, decoder_len, 1))
        mask = mask.astype(jnp.float32)
        # Decoder activations are the queries; encoder activations serve as
        # both keys and values.
        return (decoder_activations, encoder_activations,
                encoder_activations, mask)

    prepare_attention_inputs = tl.Fn(
        'PrepareAttentionInputs', _attention_inputs, n_out=4)

    attention = tl.AttentionQKV(
        d_model,
        n_heads=n_attention_heads,
        dropout=attention_dropout,
        mode=mode)

    return tl.Serial(              # in-toks, target-toks
        tl.Select([0, 1, 0, 1]),   # in-toks, target-toks, in-toks, target-toks
        tl.Parallel(input_encoder, pre_attention_decoder),
        prepare_attention_inputs,  # q, k, v, mask, target-toks
        tl.Residual(attention),    # decoder-vecs, mask, target-toks
        tl.Select([0, 2]),         # decoder-vecs, target-toks
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
        tl.Dense(target_vocab_size),
        tl.LogSoftmax()
    )
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
    """Reversible transformer language model with shortening.

    When shorten_factor is F and processing an input of shape [batch, length],
    we embed the (shifted-right) input and then group each F elements (on
    length) into a single vector -- so that in the end we process a tensor of
    shape [batch, length // F, d_model] almost until the end -- at the end
    it's un-shortened and a SRU is applied. This reduces the length processed
    inside the main model body, effectively making the model faster but
    possibly slightly less accurate.

    Args:
        vocab_size: int: vocab size
        shorten_factor: by how much to shorten, see above
        d_embedding: the depth of the embedding layer and final logits
        d_model: int: depth of *each half* of the two-part features
        d_ff: int: depth of feed-forward layer
        d_attention_key: int: depth of key vector for each attention head
        d_attention_value: int: depth of value vector for each attention head
        n_layers: int: number of decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: int: maximum symbol length for positional encoding
        n_attention_chunks: int: number of chunks for attention
        attention_type: class: attention class to use, such as
            DotProductAttention. May also be a tuple/list of classes, in which
            case the layers cycle through them and n_layers must be a multiple
            of its length.
        share_qk: bool, whether to share queries and keys. Forced on for any
            layer whose attention class subclasses tl.LSHCausalAttention.
        axial_pos_shape: tuple of ints: input shape to use for the axial
            position encoding. If unset, axial position encoding is disabled.
        d_axial_pos_embs: tuple of ints: depth of position embedding for each
            axis. Tuple length must match axial_pos_shape, values must sum to
            d_embedding.
        ff_activation: the non-linearity in feed-forward layer
        ff_use_sru: int; if > 0, we use this many SRU layers instead of
            feed-forward
        ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
        mode: str: 'train' or 'eval' ('predict' is rejected by an assert)

    Returns:
        the layer.
    """
    assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

    # Pick plain or axial positional encoding depending on axial_pos_shape.
    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=mode)
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape, d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
            dropout=dropout, mode=mode)

    positional_embedder = [
        tl.Embedding(d_embedding, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    decoder_blocks = []

    # Normalize attention_type to a list so layers can cycle through it.
    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
    else:
        attention_type = [attention_type]
    for layer_idx in range(n_layers):
        layer_attention_type = attention_type[layer_idx % len(attention_type)]
        decoder_block = DecoderBlock(
            d_model, d_ff, d_attention_key, d_attention_value, n_heads,
            n_attention_chunks,
            attention_type=layer_attention_type,
            dropout=dropout,
            share_qk=(share_qk or issubclass(layer_attention_type,
                                             tl.LSHCausalAttention)),
            ff_activation=ff_activation,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            mode=mode)
        decoder_blocks.append(decoder_block)

    # pylint: disable=g-long-lambda
    return tl.Serial(
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),  # Stack has (x, x), the first will be shortened
        # Before shortening, we need to pad by shorten factor so as not to
        # leak information into the future. To understand why, imagine
        # shorten factor of 2 and sequence of length 4, so ABCD. If we shift
        # just by 1, then we would have 0ABC, which gets grouped to [0A][BC]
        # on input, which is predicting ABCD as targets. The problem is that
        # [0A] has access to A and [BC] has access to C -- it will learn to
        # copy it, peek into the future. Shifting twice to [00][AB] solves
        # the problem as the first "big" symbol becomes all-0 and the rest is
        # shifted enough.
        tl.ShiftRight(n_shifts=shorten_factor - 1),
        tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
            x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
        tl.Dense(d_model),
        tl.Dup(),  # Stack has (short_x, short_x, x)
        tl.ReversibleSerial(decoder_blocks),
        tl.Select([0], n_in=2),
        tl.LayerNorm(),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(shorten_factor * d_embedding),
        tl.Fn(lambda x: np.reshape(  # Prolong back.
            x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
        tl.Concatenate(),  # Concatenate with just the embeddings.
        tl.CausalConv(d_embedding),
        tl.Relu(),
        tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
        tl.Dense(vocab_size),
        tl.LogSoftmax()
    )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
    """Reversible transformer language model (only uses a decoder, no encoder).

    Args:
        vocab_size: int: vocab size
        d_model: int: depth of *each half* of the two-part features
        d_ff: int: depth of feed-forward layer
        d_attention_key: int: depth of key vector for each attention head
        d_attention_value: int: depth of value vector for each attention head
        n_layers: int: number of decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: int: maximum symbol length for positional encoding
        n_chunks: int: number of chunks (must match input pipeline); 0 is
            treated as a single chunk with no input concatenation
        n_attention_chunks: int: number of chunks for attention
        attention_type: class: attention class to use, such as
            DotProductAttention. May also be a tuple/list of classes, in which
            case the layers cycle through them and n_layers must be a multiple
            of its length.
        share_qk: bool, whether to share queries and keys. Forced on for any
            layer whose attention class subclasses tl.LSHCausalAttention.
        axial_pos_shape: tuple of ints: input shape to use for the axial
            position encoding. If unset, axial position encoding is disabled.
            The strings 'fixed-base', 'infinite', 'infinite-affine' and
            'time-bin' are also accepted and select alternative positional
            encodings (marked as temporary hacks in the body).
        d_axial_pos_embs: tuple of ints: depth of position embedding for each
            axis. Tuple length must match axial_pos_shape, and values must sum
            to d_model.
        ff_activation: the non-linearity in feed-forward layer
        ff_use_sru: int; if > 0, we use this many SRU layers instead of
            feed-forward
        ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
        mode: str: 'train', 'eval', or 'predict'

    Returns:
        the layer.
    """
    # n_chunks == 0 means "unchunked": one section and nothing to concatenate.
    if n_chunks == 0:
        n_chunks = 1
        concatenate_input_chunks = []
    else:
        concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

    d_emb = d_model
    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=mode)
    elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
        # This encoding is paired with a half-width token embedding.
        d_emb //= 2
    elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.InfinitePositionalEncoding(affine=False)
    elif axial_pos_shape == 'infinite-affine':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.InfinitePositionalEncoding()
    elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
        positional_encoding = tl.TimeBinPositionalEncoding()
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape, d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
            dropout=dropout, mode=mode)

    positional_embedder = [
        tl.Embedding(d_emb, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    decoder_blocks = []

    # Normalize attention_type to a list so layers can cycle through it.
    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
    else:
        attention_type = [attention_type]
    for layer_idx in range(n_layers):
        layer_attention_type = attention_type[layer_idx % len(attention_type)]
        decoder_block = DecoderBlock(
            d_model, d_ff, d_attention_key, d_attention_value, n_heads,
            n_attention_chunks,
            attention_type=layer_attention_type,
            dropout=dropout,
            share_qk=(share_qk or issubclass(layer_attention_type,
                                             tl.LSHCausalAttention)),
            ff_activation=ff_activation,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            mode=mode)
        decoder_blocks.append(decoder_block)

    return tl.Serial(
        concatenate_input_chunks,
        tl.ShiftRight(mode=mode),
        positional_embedder,
        tl.Dup(),
        tl.ReversibleSerial(decoder_blocks + [
            SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
        ]),
        # The output head is applied to each of the n_chunks sections.
        Map([
            # TODO(kitaev): Test whether dropout should go before or after the
            # LayerNorm, and whether dropout broadcasting is needed here.
            tl.LayerNorm(),
            BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
            tl.Dense(vocab_size),
            tl.LogSoftmax(),
        ], n_sections=n_chunks),
    )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair. Per the stack comments in the body, the
    order is (encoder-side tokens, decoder-side tokens), i.e. source first,
    then target.

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    NOTE: calling this function replaces `jax.api._check_inexact_input_vjp`
    with a no-op as a process-wide side effect (see the hack comment below).

    Args:
        input_vocab_size: int: vocab size of the source.
        output_vocab_size: int (optional): vocab size of the target. If None,
            the source and target are assumed to have the same vocab.
        d_model: int: depth of embedding
        d_ff: int: depth of feed-forward layer
        n_encoder_layers: int: number of encoder layers
        n_decoder_layers: int: number of decoder layers
        n_heads: int: number of attention heads
        dropout: float: dropout rate (how much to drop out)
        max_len: int: maximum symbol length for positional encoding
        ff_activation: the non-linearity in feed-forward layer
        ff_dropout: float: (optional) separate dropout rate at feed-forward
            nonlinearity. This is called relu_dropout in T2T.
        mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
            embedding is built in 'eval' mode and the encoder is wrapped in
            tl.Cache, while the decoder side keeps mode 'predict'.

    Returns:
        A Reformer model as a layer that maps from a source, target pair to
        activations over a vocab set.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long
        # sequences.
        positional_encoding = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=mode)
        return [
            tl.Embedding(d_model, vocab_size),
            BroadcastedDropout(rate=dropout, mode=mode),
            positional_encoding,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if output_vocab_size
    # is None. This isn't supported here because (a) Trax shares weights by
    # sharing layer instances, but we need two separate instances to have
    # mode == 'eval' for the encoder but mode == 'predict' for the decoder;
    # and (b) tl.Cache does not work if its sublayers participate in any
    # weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    in_encoder = PositionalEncoder(
        input_vocab_size, mode='eval' if mode == 'predict' else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = PositionalEncoder(output_vocab_size, mode)

    encoder_blocks = [
        EncoderBlock(
            d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
        for _ in range(n_encoder_layers)]

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        # Average the two halves of the reversible stack's output.
        tl.Fn(lambda x, y: (x+y)/2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        EncoderDecoderBlock(
            d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
        for _ in range(n_decoder_layers)]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
        tl.Branch([], [                       # tok_e mask  tok_d .....
            tl.PaddingMask(),
            tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]),

        # Encode.
        encoder,                              # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
        out_encoder,                          # vec_d vec_e mask .....
        tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
        tl.LayerNorm(),                       # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),               # vec_d .....
        tl.Dense(output_vocab_size),          # vec_d .....
        tl.LogSoftmax(),                      # vec_d .....
    )
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder without enc-dec attention.

  This model expects an input pair: source, target. The source is embedded and
  run through a reversible encoder. The encoder output and the embedded,
  right-shifted target are then concatenated (using their masks) into a single
  sequence that the reversible decoder blocks process together. The target
  part is finally stripped back out, projected to the output vocabulary and
  passed through LogSoftmax.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use in the encoder
      blocks, such as SelfAttention
    encoder_decoder_attention_type: class, or tuple/list of classes: attention
      class(es) to use in the decoder blocks, such as SelfAttention. A
      tuple/list is cycled through layer by layer; n_decoder_layers must then
      be a multiple of its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset (empty), axial position encoding is disabled and
      tl.PositionalEncoding is used instead.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
      Must not be None when axial_pos_shape is set.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T. It is passed to the
      encoder blocks only.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
      built in 'eval' mode and wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set. The decoder-side tokens are left on the
    stack after the activations.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def embedding_layers(n_tokens, layer_mode):  # tokens --> vectors
    if axial_pos_shape:
      assert d_axial_pos_embs is not None
      position_layer = tl.AxialPositionalEncoding(
          shape=axial_pos_shape,
          d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout,
          mode=layer_mode)
    else:
      position_layer = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=layer_mode)

    return [
        tl.Embedding(n_tokens, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=layer_mode),
        position_layer,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = embedding_layers(input_vocab_size, encoder_mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = embedding_layers(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type, dropout,
                   ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # A single attention class is treated as a cycle of length one.
  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    attention_cycle = encoder_decoder_attention_type
    assert n_decoder_layers % len(attention_cycle) == 0
  else:
    attention_cycle = [encoder_decoder_attention_type]

  # pylint: disable=g-complex-comprehension
  decoder_blocks = [
      DecoderBlock(
          d_model, d_ff, d_attention_key, d_attention_value, n_heads,
          attention_type=attention_cycle[layer_idx % len(attention_cycle)],
          dropout=dropout,
          ff_activation=ff_activation,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          mode=mode)
      for layer_idx in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),             # tok_e tok_e tok_d tok_d
      tl.Branch(
          [],
          [
              tl.PaddingMask(),
              tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
          ],
      ),                                   # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                             # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),             # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),            # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray(),
      ),                          # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),    # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),       # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                                    # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),         # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                     # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),          # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),                 # vec_d tok_d
      tl.LogSoftmax(),                             # vec_d tok_d
  )
def ConfigurableTransformerLM(vocab_size,
                              d_model=512,
                              d_ff=2048,
                              n_layers=6,
                              n_heads=8,
                              max_len=2048,
                              dropout=0.1,
                              dropout_shared_axes=None,
                              mode='train',
                              ff_activation=tl.Relu,
                              ff_dropout=0.1,
                              ff_chunk_size=0,
                              ff_use_sru=0,
                              ff_sparsity=0,
                              ff_sparsity_type='1inN',
                              loss_sparsity_type='mult',
                              loss_sparsity=0,
                              loss_d_lowrank=0,
                              loss_sparsity_prob=None,
                              attention_chunk_size=0,
                              attention_type=tl.CausalAttention,
                              pos_type=None,
                              pos_axial_shape=None,
                              pos_d_axial_embs=None):
  """Returns a Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor with one vector of activations over possible token
      IDs for each sequence position; shape is (batch_size, sequence_length,
      `vocab_size`). The last layer assembled here is
      `tl.SparseDenseWithOptions`; no LogSoftmax layer is applied in this
      function.

  This model uses only the decoder part of the overall Transformer.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each decoder
      block.
    n_layers: Number of decoder blocks. Each block includes attention, dropout,
      residual, feed-forward (`Dense`), and activation layers.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within a decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each decoder block
      will include dropout; else, it will pass all values through unaltered.
    ff_activation: Type of activation function at the end of each decoder
      block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers in
      addition to the feed-forward block (second int specifies sru size)
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0, use SparseFF if
      ff_sparsity_type=`'1inN'` and use BlockSparseFF if
      ff_sparsity_type=`'Block'`
    loss_sparsity_type: string, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    attention_type: The attention layer to use for the decoder part.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  # Token embedding --> dropout --> positional encoding.
  token_embedding = tl.Embedding(vocab_size, d_model)
  embedding_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
  position_encoding = PositionalEncoder(
      mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs)
  embedder = [token_embedding, embedding_dropout, position_encoding]

  def new_decoder_block():
    # Every block gets the same hyperparameters but is a separate instance.
    return DecoderBlock(
        d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
        ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
        ff_sparsity, ff_sparsity_type,
        attention_chunk_size, attention_type)

  decoder_stack = [new_decoder_block() for _ in range(n_layers)]

  # Assemble and return the model.
  return tl.Serial(               # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),   # toks
      embedder,                   # vecs
      decoder_stack,              # vecs
      tl.LayerNorm(),             # vecs
      tl.SparseDenseWithOptions(  # vecs
          vocab_size,
          d_input=d_model,
          sparsity_type=loss_sparsity_type,
          sparsity=loss_sparsity,
          d_lowrank=loss_d_lowrank,
          prob_sparse=loss_sparsity_prob,
          mode=mode),
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               loss_sparsity_type='mult',
               loss_sparsity=0,
               loss_d_lowrank=0,
               loss_sparsity_prob=None,
               attention_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Tokens are shifted right, embedded (with dropout and positional encoding),
  duplicated into the two streams of a reversible stack of decoder blocks,
  concatenated back together, normalized, passed through dropout and finally
  through a `tl.SparseDenseWithOptions` output layer.

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out); also passed to the
      decoder blocks as their feed-forward dropout rate
    max_len: int: maximum symbol length for positional encoding
    attention_type: class, or tuple/list of classes: attention class(es) to
      use, such as SelfAttention. A tuple/list is cycled through layer by
      layer; n_layers must then be a multiple of its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  position_layer = ct.PositionalEncoder(
      mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)

  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      position_layer,
  ]

  # A single attention class is treated as a cycle of length one.
  if isinstance(attention_type, (tuple, list)):
    attention_cycle = attention_type
    assert n_layers % len(attention_cycle) == 0
  else:
    attention_cycle = [attention_type]

  # pylint: disable=g-complex-comprehension
  decoder_blocks = [
      DecoderBlock(
          d_model, d_ff, d_attention_key, d_attention_value, n_heads,
          attention_type=attention_cycle[layer_idx % len(attention_cycle)],
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          mode=mode)
      for layer_idx in range(n_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # NOTE(review): tl.Concatenate() below joins the two reversible halves, so
  # the features reaching this layer are presumably 2 * d_model wide while
  # d_input=d_model is passed here -- confirm against SparseDenseWithOptions.
  output_layer = tl.SparseDenseWithOptions(
      vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      mode=mode)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks),
      tl.Concatenate(),
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      output_layer,
  )
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target. Both are embedded, the
  source is run through the encoder blocks, and the encoder output is then
  concatenated with the embedded target so that the decoder blocks process
  one joint sequence. The target part is stripped back out before the output
  layer.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head. If
      None, d_model // n_heads is used.
    d_attention_value: int: depth of value vector for each attention head. If
      None, d_model // n_heads is used.
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use in the encoder
      blocks, such as SelfAttention
    encoder_decoder_attention_type: class, or tuple/list of classes: attention
      class(es) to use in the decoder blocks, such as SelfAttention. A
      tuple/list is cycled through layer by layer; n_decoder_layers must then
      be a multiple of its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled. The value
      (including the string default 'fixed-base') is forwarded unchanged to
      ct.EmbeddingAndPositionalEncodings.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
      wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.

  Raises:
    ValueError: if d_attention_key or d_attention_value is None and n_heads
      does not divide d_model.
  """
  # Set default dimensions for attention head key and value sizes.
  if d_attention_key is None or d_attention_value is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_per_head = d_model // n_heads
    if d_attention_key is None:
      d_attention_key = d_per_head
    if d_attention_value is None:
      d_attention_value = d_per_head

  # Vector embeddings.
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs,
          use_bfloat16=use_bfloat16))

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type,
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          use_bfloat16=use_bfloat16,
          mode=mode)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder_layers = [                    # vec_e mask_e tok_e vec_d tok_d
      tl.ReversibleSelect([0, 0]),      # vec_e1 vec_e2 mask_e tok_e vec_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
  ]
  if not reversible_encoder:
    # Merge the two reversible streams back into a single encoder vector.
    encoder_layers.extend([
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model, use_bfloat16=use_bfloat16),
        tl.LayerNorm(),
    ])
  encoder = tl.Serial(encoder_layers)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # A single attention class is treated as a cycle of length one.
  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    attention_cycle = encoder_decoder_attention_type
    assert n_decoder_layers % len(attention_cycle) == 0
  else:
    attention_cycle = [encoder_decoder_attention_type]

  # pylint: disable=g-complex-comprehension
  decoder_blocks = [
      DecoderBlock(
          d_model, d_ff, d_attention_key, d_attention_value, n_heads,
          attention_type=attention_cycle[layer_idx % len(attention_cycle)],
          dropout=dropout,
          ff_activation=ff_activation,
          ff_dropout=ff_dropout,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          ff_sparsity=ff_sparsity,
          attention_chunk_size=attention_chunk_size,
          n_attention_layers=n_decoder_attention_layers,
          use_bfloat16=use_bfloat16,
          mode=mode)
      for layer_idx in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  # Layers to merge encoder and decoder, see below for details.
  if reversible_encoder:
    encdec_layers = [
        tl.ReversibleSelect([0, 1, 4, 2, 3]),  # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding2(mode=mode),      # vec_ed vec_ed tok_e tok_d
    ]
  else:
    encdec_layers = [
        tl.ReversibleSelect([0, 3, 1, 2]),     # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),       # vec_ed tok_e tok_d
        tl.ReversibleSelect([0, 0]),           # vec_ed vec_ed tok_e tok_d
    ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),          # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(
          in_encoder,
          [],
          [],
          [tl.ShiftRight(mode=mode), out_encoder],
      ),                                   # vec_e tok_e tok_e vec_d tok_d

      tl.Parallel(
          [],
          [
              tl.PaddingMask(),
              tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
          ],
      ),                                   # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                             # vec_e mask_e tok_e vec_d tok_d

      # Concat encoder and decoder, given encoder mask.
      encdec_layers,

      # Run decoder blocks.
      _ReversibleSerialForget(
          decoder_blocks, d_model, n_layers_forget,
      ),                                           # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,                               # vec_d tok_d
  )
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            pos_type=None,
                            pos_axial_shape=None,
                            pos_d_axial_embs=None,
                            enc_dec_attention_sparsity=0):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor with one vector of activations over possible token
      IDs for each sequence position; shape is (batch_size, sequence_length,
      `vocab_size`). The last layer assembled here is
      `tl.SparseDenseWithOptions`; no LogSoftmax layer is applied in this
      function.

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
      if None, then input and target integers (token IDs) are assumed to come
      from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference (the encoder is additionally
      wrapped in `tl.Cache`). If `'train'`, each encoder/decoder block will
      include dropout; else, it will pass all values through unaltered.
    ff_activation: Type of activation function at the end of each
      encoder/decoder block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers in
      addition to the feed-forward block (second int specifies sru size)
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0, use SparseFF if
      ff_sparsity_type=`'1inN'` and use BlockSparseFF if
      ff_sparsity_type=`'Block'`
    loss_sparsity_type: str, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    encoder_decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    enc_dec_attention_sparsity: int, if > 0 use this sparsity in attention.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  # Source/target embedders; output_vocab_size comes back from the helper.
  in_encoder, out_encoder, output_vocab_size = EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      dropout_shared_axes,
      max_len,
      output_vocab_size=output_vocab_size,
      pos_type=pos_type,
      pos_axial_shape=pos_axial_shape,
      pos_d_axial_embs=pos_d_axial_embs)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
          ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
          ff_sparsity, ff_sparsity_type,
          attention_chunk_size, encoder_attention_type)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # pylint: disable=g-complex-comprehension
  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
          ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
          ff_sparsity, ff_sparsity_type,
          attention_chunk_size, encoder_decoder_attention_type,
          enc_dec_attention_sparsity)
      for _ in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
      encoder,                                 # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),                # tok_d ..... ..... .....
      out_encoder,                             # vec_d ..... ..... .....
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      encoder_decoder_blocks,                  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                  # vec_d tok_d
      tl.SparseDenseWithOptions(               # vec_d .....
          output_vocab_size,
          d_input=d_model,
          sparsity_type=loss_sparsity_type,
          sparsity=loss_sparsity,
          d_lowrank=loss_d_lowrank,
          prob_sparse=loss_sparsity_prob,
          mode=mode),
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Tokens are shifted right, embedded, duplicated into the two streams of a
  reversible stack of decoder blocks, split into `n_chunks` sections, and each
  section is mapped through LayerNorm, dropout, Dense and LogSoftmax.

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline). 0 is treated
      as a single chunk: the input is not concatenated, and the output is
      passed through a one-item Concatenate instead.
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if n_chunks != 0:
    # Chunked input pipeline: join the chunks on the way in only.
    merge_inputs = tl.Concatenate(n_items=n_chunks)
    merge_outputs = []
  else:
    n_chunks = 1
    merge_inputs = []
    merge_outputs = tl.Concatenate(n_items=n_chunks, axis=-2)

  embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.PositionalEncoding(max_len=max_len),
  ]

  shift = tl.ShiftRight()
  duplicate = tl.Dup()

  # Decoder blocks followed by the split that prepares per-chunk outputs.
  reversible_layers = [
      DecoderBlock(d_model, d_ff,
                   d_attention_key, d_attention_value, n_heads,
                   n_attention_chunks, attention_type,
                   dropout, share_qk, mode)
      for _ in range(n_layers)
  ]
  reversible_layers.append(
      SplitForOutput(n_sections=n_chunks, axis=-2))  # pylint: disable=no-value-for-parameter
  reversible_stack = tl.ReversibleSerial(reversible_layers)

  per_chunk_head = Map([
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ], n_sections=n_chunks)

  return tl.Model(
      merge_inputs,
      shift,
      embedder,
      duplicate,
      reversible_stack,
      per_chunk_head,
      merge_outputs,
  )
def _shift_right(n):  # pylint: disable=invalid-name
  """Returns a list of `n` references to a single `tl.ShiftRight` layer.

  Only one layer object is created; every entry of the returned list is that
  same object.

  Args:
    n: int: number of entries in the returned list.

  Returns:
    A list of length `n` (empty when `n` is not positive).
  """
  shift_layer = tl.ShiftRight()
  return [shift_layer for _ in range(n)]
def FunnelTransformerLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        vanilla_layers=(0, 1),
                        shorten_factors=(3,),
                        n_funnel_blocks=(6,),
                        n_heads=8,
                        dropout=0.1,
                        dropout_shared_axes=None,
                        mode='train',
                        ff_activation=tl.FastGelu):
    """Returns a funnel-style Transformer language model (decoder only).

    This model performs autoregressive language modeling:

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length). The
        tensor elements are integers in `range(vocab_size)`, and `0` values
        mark padding positions.

      - output: rank 3 tensor with one vector of size `vocab_size` per
        sequence position; shape is (batch_size, sequence_length,
        `vocab_size`). The last layer assembled here is `tl.Dense(vocab_size)`;
        no softmax layer is appended by this function.

    The token-level stream is duplicated after the pre-decoder blocks. One
    copy is shifted right, passed through the shortening (funnel) stages,
    then through dropout, a single upsampling block and a LayerNorm. The two
    copies are then concatenated, passed through a causal convolution with
    activation, the post-decoder blocks, and the final dense layer.

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        decoder block.
      vanilla_layers: (pre_layers, post_layers) tuple - number of full
        token-level Transformer decoder layers before and after shortening.
      shorten_factors: Tuple of arbitrary length giving the shortening factor
        applied at each pooling stage.
      n_funnel_blocks: Tuple of the same length as `shorten_factors` giving
        the number of Transformer decoder blocks after each pooling stage.
      n_heads: Number of attention heads.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      mode: str: 'train' or 'eval'. 'predict' is rejected by an assertion.
      ff_activation: Type of activation function at the end of each block;
        must be an activation-type subclass of `Layer`. It is also
        instantiated once after the causal convolution.

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    assert mode != 'predict'  # For now, 'predict' mode is unsupported.
    assert len(n_funnel_blocks) == len(shorten_factors)

    token_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

    # The same relative-attention bias layers are handed to every block.
    context_bias_layer, location_bias_layer = _get_rel_att_inputs(
        d_model, n_heads)

    n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

    def _decoder_stack(depth, total_pooling):  # pylint: disable=invalid-name
        """Returns `depth` relative decoder blocks followed by a LayerNorm."""
        stack = [  # pylint: disable=g-complex-comprehension
            _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                                  dropout_shared_axes, mode, ff_activation,
                                  context_bias_layer, location_bias_layer,
                                  total_pooling)
            for _ in range(depth)
        ]
        stack.append(tl.LayerNorm())
        return stack

    def _resampling_block(total_pooling, shorten_factor, resampler_fn):  # pylint: disable=invalid-name
        """Returns one funnel decoder block that uses `resampler_fn`."""
        return _FunnelRelativeDecoderBlock(
            d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
            ff_activation,
            context_bias_layer=context_bias_layer,
            location_bias_layer=location_bias_layer,
            total_pooling=total_pooling,
            shorten_factor=shorten_factor,
            resampler_fn=resampler_fn)

    # Running product of the shortening factors applied so far.
    total_pooling_acc = 1
    pre_decoder_blocks = _decoder_stack(n_pre_decoder_blocks, total_pooling=1)

    funnel_blocks = []
    for factor, stage_depth in zip(shorten_factors, n_funnel_blocks):
        # The downsampling block is built with the pooling accumulated
        # *before* this stage; the blocks after it use the updated product.
        funnel_blocks.append(
            _resampling_block(total_pooling_acc, factor, _DownsamplerLM))
        total_pooling_acc *= factor
        funnel_blocks.extend(_decoder_stack(stage_depth, total_pooling_acc))

    # A single upsampling block undoes the whole accumulated shortening.
    upsampling_layer = _resampling_block(
        total_pooling_acc, total_pooling_acc, _UpsamplerLM)

    conv_layer = tl.Serial(
        tl.CausalConv(d_model, total_pooling_acc),
        ff_activation(),
    )

    post_decoder_blocks = _decoder_stack(n_post_decoder_blocks,
                                         total_pooling=1)

    # Assemble and return the model.
    return tl.Serial(              # tokens (or chunked tuple of tokens)
        tl.ShiftRight(mode=mode),  # toks
        token_encoder,             # vecs
        pre_decoder_blocks,        # vecs
        tl.Dup(),
        # NOTE(review): the extra shift by (total pooling - 1) positions
        # presumably keeps the shortened branch autoregressive -- confirm.
        tl.ShiftRight(n_positions=total_pooling_acc - 1),
        funnel_blocks,
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        upsampling_layer,
        tl.LayerNorm(),
        tl.Concatenate(),
        conv_layer,
        post_decoder_blocks,
        tl.Dense(vocab_size),      # vecs
    )