def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
  """Builds a reversible, decoder-only Transformer language model.

  The returned stack shifts the tokens right, embeds them and adds positional
  information, duplicates the activations into the two streams consumed by the
  reversible decoder blocks, concatenates the streams again and finishes with
  LayerNorm, dropout, a Dense projection to `vocab_size` and LogSoftmax.

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention. A
      tuple or list of classes is also accepted; its length must divide
      n_layers and the classes are then cycled through the layers in order.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset (empty), tl.PositionalEncoding is used instead. As a
      temporary HACK the strings 'fixed-base', 'infinite', 'infinite-affine'
      and 'time-bin' select other positional encodings; 'fixed-base' also
      halves the depth of the token embedding.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
      Must not be None when axial position encoding is used.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """

  def _PositionalEncodingAndEmbeddingDepth():
    """Returns (positional encoding layer, token embedding depth)."""
    if not axial_pos_shape:
      standard = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
      return standard, d_model
    # TODO(lukaszkaiser): remove the string-valued HACKs below.
    if axial_pos_shape == 'fixed-base':
      # Only this variant changes the embedding depth (it is halved).
      return tl.FixedBasePositionalEncoding(mode=mode), d_model // 2
    if axial_pos_shape == 'infinite':
      return tl.InfinitePositionalEncoding(affine=False), d_model
    if axial_pos_shape == 'infinite-affine':
      return tl.InfinitePositionalEncoding(), d_model
    if axial_pos_shape == 'time-bin':
      return tl.TimeBinPositionalEncoding(), d_model
    # Genuine axial encoding: per-axis embedding depths are mandatory.
    assert d_axial_pos_embs is not None
    n_axes = len(axial_pos_shape)
    axial = tl.AxialPositionalEncoding(
        shape=axial_pos_shape,
        d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
        dropout=dropout,
        mode=mode)
    return axial, d_model

  positional_encoding, d_emb = _PositionalEncodingAndEmbeddingDepth()

  positional_embedder = [
      tl.Embedding(d_emb, vocab_size),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # Normalize attention_type to a sequence that is cycled over the layers.
  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
    attention_cycle = attention_type
  else:
    attention_cycle = [attention_type]

  decoder_blocks = [
      DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                   attention_type=attention_cycle[idx % len(attention_cycle)],
                   dropout=dropout,
                   ff_activation=ff_activation,
                   ff_use_sru=ff_use_sru,
                   ff_chunk_size=ff_chunk_size,
                   mode=mode)
      for idx in range(n_layers)
  ]

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks),
      tl.Concatenate(),
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               loss_sparsity_type='mult',
               loss_sparsity=0,
               loss_d_lowrank=0,
               loss_sparsity_prob=None,
               attention_chunk_size=0,
               mode='train'):
  """Builds a reversible, decoder-only Transformer language model.

  The returned stack shifts the tokens right, embeds them and adds positional
  information, duplicates the activations into the two streams consumed by the
  reversible decoder blocks, concatenates the streams again and finishes with
  LayerNorm, dropout and a `tl.SparseDenseWithOptions` output layer.

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out); also passed to the
      decoder blocks as their feed-forward dropout
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention. A
      tuple or list of classes is also accepted; its length must divide
      n_layers and the classes are then cycled through the layers in order.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to use in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  positional_embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      ct.PositionalEncoder(
          mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs),
  ]

  # Normalize attention_type to a sequence that is cycled over the layers.
  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
    attention_cycle = attention_type
  else:
    attention_cycle = [attention_type]

  # Keyword arguments that are identical for every decoder block.
  shared_block_kwargs = dict(
      dropout=dropout,
      ff_activation=ff_activation,
      ff_dropout=dropout,
      ff_use_sru=ff_use_sru,
      ff_chunk_size=ff_chunk_size,
      ff_sparsity=ff_sparsity,
      attention_chunk_size=attention_chunk_size,
      mode=mode)
  decoder_blocks = [
      DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                   attention_type=attention_cycle[idx % len(attention_cycle)],
                   **shared_block_kwargs)
      for idx in range(n_layers)
  ]

  output_layer = tl.SparseDenseWithOptions(
      vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      mode=mode)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks),
      tl.Concatenate(),
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      output_layer,
  )
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      attention_type=tl.SelfAttention,
                      pos_type=None,
                      pos_axial_shape=(),
                      pos_d_axial_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      ff_sparsity=0,
                      attention_chunk_size=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape ::

      [batch, length // F, d_model]

  almost until the end -- at the end it is un-shortened, concatenated with the
  plain embeddings, and passed through a causal convolution and an SRU layer.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out); also passed to the
      decoder blocks as their feed-forward dropout
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention. A
      tuple or list of classes is also accepted; its length must divide
      n_layers and the classes are then cycled through the layers in order.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval' ('predict' is rejected by an assertion)

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  def _Shorten(x):
    """Groups every `shorten_factor` positions by moving them to depth."""
    return jnp.reshape(x, (x.shape[0], x.shape[1] // shorten_factor, -1))

  def _ProlongBack(x):
    """Undoes the grouping: splits depth back into `shorten_factor` steps."""
    return jnp.reshape(x, (x.shape[0], x.shape[1] * shorten_factor, -1))

  positional_embedder = [
      tl.Embedding(vocab_size, d_embedding),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      ct.PositionalEncoder(
          mode, dropout, max_len, pos_type, pos_axial_shape, pos_d_axial_embs),
  ]

  # Normalize attention_type to a sequence that is cycled over the layers.
  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
    attention_cycle = attention_type
  else:
    attention_cycle = [attention_type]

  # Keyword arguments that are identical for every decoder block.
  shared_block_kwargs = dict(
      dropout=dropout,
      ff_activation=ff_activation,
      ff_dropout=dropout,
      ff_use_sru=ff_use_sru,
      ff_chunk_size=ff_chunk_size,
      ff_sparsity=ff_sparsity,
      attention_chunk_size=attention_chunk_size,
      mode=mode)
  decoder_blocks = [
      DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                   attention_type=attention_cycle[idx % len(attention_cycle)],
                   **shared_block_kwargs)
      for idx in range(n_layers)
  ]

  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_positions=shorten_factor - 1),
      tl.Fn('Shorten', _Shorten, n_out=1),  # Shorten -- move to depth.
      tl.Dense(d_model),
      tl.Dup(),              # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn('ProlongBack', _ProlongBack, n_out=1),  # Prolong back.
      tl.Concatenate(),      # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),   # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
  )
def EncoderDecoderBlock(d_model,
                        d_ff,
                        n_heads,
                        dropout,
                        dropout_shared_axes,
                        mode,
                        ff_activation,
                        ff_dropout,
                        ff_chunk_size,
                        ff_use_sru,
                        ff_sparsity,
                        ff_sparsity_type,
                        attention_chunk_size,
                        attention_type):
  """Returns a list of layers implementing a Transformer encoder-decoder block.

  The input is a triple (decoder_activations, mask, encoder_activations) where
  the mask is created from the original input token IDs to prevent attending to
  the padding part of the encoder. The block consists of three residual
  sublayers: causal self-attention over the decoder activations, attention
  from the decoder activations to the encoder activations, and feed-forward.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
    attention_chunk_size: int, if > 0 run attention chunked at this size
    attention_type: The attention layer to use for the causal self-attention
        sublayer.

  Returns:
    A list of layers which maps triples (decoder_activations, mask,
    encoder_activations) to triples of the same sort.
  """
  # Every residual dropout in this block is configured identically.
  dropout_kwargs = dict(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
  d_head = d_model // n_heads  # Per-head depth for both keys and values.

  # TODO(afrozm): This layer isn't configurable because: We currently don't have
  # any alternative for it (LSH cannot do it fundamentally, that's why we have
  # NoEncDec models, and local attention doesn't make sense in the general
  # setting where we don't know what in input is local to what in output;
  # some variants of FAVOR can do it, so maybe in the future,
  # but we don't have them yet).
  encoder_decoder_attention = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  decoder_self_attention = ApplyAttentionLayer(
      attention_type,
      d_model,
      n_heads,
      d_head,
      d_head,
      causal=True,
      masked=True,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)

  feed_forward = FeedForwardWithOptions(
      d_model, d_ff, dropout, dropout_shared_axes, ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode, ff_sparsity_type)

  self_attention_sublayer = tl.Residual(  # vec_d masks vec_e
      tl.LayerNorm(),                     # vec_d ..... .....
      decoder_self_attention,             # vec_d ..... .....
      tl.Dropout(**dropout_kwargs),       # vec_d ..... .....
  )
  cross_attention_sublayer = tl.Residual(
      tl.LayerNorm(),                     # vec_d ..... .....
      tl.Select([0, 2, 2, 1, 2]),         # vec_d vec_e vec_e masks vec_e
      encoder_decoder_attention,          # vec_d masks vec_e
      tl.Dropout(**dropout_kwargs),       # vec_d masks vec_e
  )
  feed_forward_sublayer = tl.Residual(feed_forward)  # vec_d masks vec_e

  return [
      self_attention_sublayer,
      cross_attention_sublayer,
      feed_forward_sublayer,
  ]
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             axial_pos_shape=None,
             d_axial_pos_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
  """Reversible transformer encoder-decoder model.

  This model expects a pair of token tensors as input: the encoder-side
  (source) tokens first, followed by the decoder-side (target) tokens.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'; when it is 'predict' the encoder is wrapped
      in tl.Cache.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """

  def _Average(x, y):
    """Merges the two reversible streams by taking their mean."""
    return (x + y) / 2.0

  def _SqueezeMask(x):
    """Drops axes 1 and 2 of the padding mask."""
    return jnp.squeeze(x, (1, 2))

  # Feed-forward options shared by encoder and encoder-decoder blocks.
  ff_options = dict(
      ff_use_sru=ff_use_sru,
      ff_chunk_size=ff_chunk_size,
      ff_sparsity=ff_sparsity)

  embeddings = ct.EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      [-2],  # dropout_shared_axes
      max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs)
  in_encoder, out_encoder, output_vocab_size = embeddings

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                   ff_activation, ff_dropout, mode=mode, **ff_options)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder_layers = [
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', _Average),
      tl.LayerNorm(),
  ]
  encoder = tl.Serial(encoder_layers)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # pylint: disable=g-complex-comprehension
  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                          ff_dropout, mode, **ff_options)
      for _ in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  padding_mask = [
      tl.PaddingMask(),
      tl.Fn('Squeeze', _SqueezeMask, n_out=1),
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], padding_mask),          # tok_e mask  tok_d .....

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg', _Average),             # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
  )
def ConfigurableTransformerLM(vocab_size,
                              d_model=512,
                              d_ff=2048,
                              n_layers=6,
                              n_heads=8,
                              max_len=2048,
                              dropout=0.1,
                              dropout_shared_axes=None,
                              mode='train',
                              ff_activation=tl.Relu,
                              ff_dropout=0.1,
                              ff_chunk_size=0,
                              ff_use_sru=0,
                              ff_sparsity=0,
                              ff_sparsity_type='1inN',
                              attention_chunk_size=0,
                              attention_type=tl.CausalAttention,
                              axial_pos_shape=None,
                              d_axial_pos_embs=None):
  """Returns a Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs;
      shape is (batch_size, sequence_length). The tensor elements are integers
      in `range(vocab_size)`.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  This model uses only the decoder part of the overall Transformer: the tokens
  are shifted right, embedded with positional information, passed through
  `n_layers` decoder blocks, normalized, and projected to the vocabulary by a
  Dense layer followed by LogSoftmax.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each decoder
        block.
    n_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each decoder block
        will include dropout; else, it will pass all values through unaltered.
    ff_activation: Type of activation function at the end of each decoder
        block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
    attention_chunk_size: int, if > 0 run attention chunked at this size
    attention_type: The attention layer to use for the decoder part.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  token_embedding = tl.Embedding(vocab_size, d_model)
  embedding_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
  position_encoding = PositionalEncoder(
      mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)
  positional_encoder = [token_embedding, embedding_dropout, position_encoding]

  decoder_blocks = []
  for _ in range(n_layers):
    decoder_blocks.append(
        DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, ff_dropout, ff_chunk_size,
                     ff_use_sru, ff_sparsity, ff_sparsity_type,
                     attention_chunk_size, attention_type))

  # Assemble and return the model.
  return tl.Serial(              # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      positional_encoder,        # vecs
      decoder_blocks,            # vecs
      tl.LayerNorm(),            # vecs
      tl.Dense(vocab_size),      # vecs
      tl.LogSoftmax(),           # vecs
  )
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            axial_pos_shape=None,
                            d_axial_pos_embs=None):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if None, then input and target integers (token IDs) are assumed to come
        from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'predict'`, use fast inference (the encoder is additionally
        wrapped in `tl.Cache`). If `'train'`, each encoder/decoder block will
        include dropout; else, it will pass all values through unaltered.
    ff_activation: Type of activation function at the end of each
        encoder/decoder block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    encoder_decoder_attention_type: The attention layer to use for the
        encoder-decoder attention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  embeddings = EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      dropout_shared_axes,
      max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs)
  in_encoder, out_encoder, output_vocab_size = embeddings

  # Leading positional arguments common to both kinds of block; only the
  # trailing attention type differs between encoder and encoder-decoder blocks.
  shared_block_args = (
      d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
      ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
      ff_sparsity_type, attention_chunk_size)

  encoder_blocks = [
      EncoderBlock(*shared_block_args, encoder_attention_type)
      for _ in range(n_encoder_layers)
  ]
  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(*shared_block_args, encoder_decoder_attention_type)
      for _ in range(n_decoder_layers)
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
      encoder,                                 # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),                # tok_d ..... ..... .....
      out_encoder,                             # vec_d ..... ..... .....
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      encoder_decoder_blocks,                  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                  # vec_d tok_d
      tl.Dense(output_vocab_size),             # vec_d .....
      tl.LogSoftmax(),                         # vec_d .....
  )
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
  """Returns a Transformer encoder merged with an N-way categorization head.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 2 tensor representing a batch of activations over N
      categories; shape is (batch_size, `n_classes`). The stack ends in a
      `Dense` layer; no `LogSoftmax` is applied inside this function.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
        classification.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    n_layers: Number of encoder blocks. Each block includes attention, dropout,
        residual, feed-forward (`Dense`), and activation layers.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.

  Returns:
    A Transformer model that maps strings (conveyed via token IDs) to
    activations over a range of output classes.
  """
  token_embedding = tl.Embedding(vocab_size, d_model)
  embedding_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
  positional_encoder = [
      token_embedding,
      embedding_dropout,
      tl.PositionalEncoding(max_len=max_len),
  ]

  encoder_blocks = []
  for _ in range(n_layers):
    encoder_blocks.append(
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation))

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Map to output categories.
      tl.Mean(axis=1),                            # vecs
      tl.Dense(n_classes),                        # vecs
  )
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                max_len=2048,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.Relu):
    """Builds a full encoder-decoder Transformer.

    The model performs tokenized source-to-target sequence transduction
    (for example, translating tokenized sentences between languages):

    - inputs (2):
      - source: rank 2 tensor of token IDs, shape (batch_size,
        sequence_length), values in `range(input_vocab_size)`; `0` marks
        padding.
      - target: rank 2 tensor of token IDs, shape (batch_size,
        sequence_length), values in `range(output_vocab_size)`; `0` marks
        padding.
    - output: rank 3 tensor of per-position activations over the output
      vocabulary, produced by a final `Dense(output_vocab_size)` layer. A
      copy of the target tokens is left on the stack after it (it is
      duplicated at the start "for use in loss").

    Args:
      input_vocab_size: Size of the source vocabulary.
      output_vocab_size: Size of the target vocabulary. If None, source and
        target share one vocabulary and the same embedding layer instances.
      d_model: Feature depth used throughout the model, starting with the
        embedding output.
      d_ff: Width of the dense layer in each block's feed-forward part.
      n_encoder_layers: Number of encoder blocks.
      n_decoder_layers: Number of decoder blocks.
      n_heads: Number of attention heads.
      max_len: Longest sequence supported by the positional encoding.
      dropout: Dropout probability used inside the blocks and embeddings.
      dropout_shared_axes: Axes along which a dropout mask is shared, e.g.
        `(0, 1)` to share over batch and sequence axes.
      mode: `'predict'` selects fast (stepwise) inference, `'train'` enables
        dropout; any other value passes values through unaltered.
      ff_activation: Activation layer class used at the end of each block's
        feed-forward part.

    Returns:
      A `tl.Serial` layer mapping a (source, target) token pair to
      activations over the output vocabulary.
    """
    def _TokenEmbedding(n_tokens):  # tokens --> vectors
        return [
            tl.Embedding(n_tokens, d_model),
            tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,
                       mode=mode),
        ]

    source_embedding = _TokenEmbedding(input_vocab_size)
    if output_vocab_size is None:
        # Shared vocabulary: reuse the very same layer instances so that the
        # embedding weights are shared between encoder and decoder.
        target_embedding = source_embedding
    else:
        target_embedding = _TokenEmbedding(output_vocab_size)

    # Positional encodings are never shared between encoder and decoder.
    # The encoder does not run stepwise, so it never uses 'predict' mode.
    if mode == 'predict':
        encoder_mode = 'eval'
    else:
        encoder_mode = mode
    source_pipeline = source_embedding + [
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode)
    ]
    target_pipeline = target_embedding + [
        tl.PositionalEncoding(max_len=max_len, mode=mode)
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    encoder_stack = []
    for _ in range(n_encoder_layers):
        encoder_stack.append(
            _EncoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    encoder = tl.Serial(source_pipeline, encoder_stack, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_stack = []
    for _ in range(n_decoder_layers):
        decoder_stack.append(
            _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                 dropout_shared_axes, mode, ff_activation))

    # Stack contents after each layer are noted on the right.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
        encoder,                                 # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),                # tok_d ..... ..... .....
        target_pipeline,                         # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        decoder_stack,                           # vec_d masks ..... .....
        tl.LayerNorm(),                          # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),                  # vec_d tok_d
        tl.Dense(output_vocab_size),             # vec_d .....
    )
def TransformerDecoder(vocab_size=None,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
    """Builds a Transformer decoder (sequence in, sequence out).

    - input when `vocab_size` is given: rank 2 tensor of token IDs, shape
      (batch_size, sequence_length), values in `range(vocab_size)`; `0`
      marks padding.
    - input when `vocab_size` is None: rank 3 tensor of activation vectors,
      shape (batch_size, sequence_length, `d_model`).
    - output: rank 3 tensor, shape (batch_size, sequence_length, `d_model`).

    Attention is causal and the input is *not* shifted right, so the output
    at position `t` depends on inputs up to and including position `t`.

    Args:
      vocab_size: Input vocabulary size. If given, inputs are embedded with
        `tl.Embedding`; if None, float vector inputs are projected with
        `tl.Dense(d_model)` instead.
      d_model: Feature depth used throughout the model.
      d_ff: Width of the dense layer in each block's feed-forward part.
      n_layers: How many decoder blocks to stack.
      n_heads: Number of attention heads.
      max_len: Longest sequence supported by the positional encoding.
      dropout: Dropout probability used in the input stage and the blocks.
      dropout_shared_axes: Axes along which a dropout mask is shared, e.g.
        `(0, 1)` to share over batch and sequence axes.
      mode: `'train'` enables dropout; any other value disables it.
      ff_activation: Activation layer class used at the end of each block's
        feed-forward part.

    Returns:
      A `tl.Serial` layer mapping token IDs (if `vocab_size` is given) or
      activation vectors (if it is None) to sequences of activation vectors.
    """
    # Pick the input projection: embedding for tokens, dense for vectors.
    if vocab_size is None:
        input_projection = tl.Dense(d_model)
    else:
        input_projection = tl.Embedding(vocab_size, d_model)

    input_stage = [
        input_projection,
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    # Stack contents after each layer are noted on the right.
    return tl.Serial(    # toks
        input_stage,     # vecs
        blocks,          # vecs
        tl.LayerNorm(),  # vecs
    )
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  max_len=2048,
                  dropout=0.1,
                  dropout_shared_axes=None,
                  mode='train',
                  ff_activation=tl.Relu):
    """Builds a decoder-only Transformer language model.

    The model performs autoregressive language modeling:

    - input: rank 2 tensor of token IDs, shape (batch_size,
      sequence_length), values in `range(vocab_size)`; `0` marks padding.
    - output: rank 3 tensor of per-position activations over the vocabulary,
      shape (batch_size, sequence_length, `vocab_size`), produced by a final
      `Dense(vocab_size)` layer.

    The input is shifted right before embedding, and only decoder blocks are
    used.

    Args:
      vocab_size: Vocabulary size; inputs are integers in
        `range(vocab_size)`.
      d_model: Feature depth used throughout the model, starting with the
        embedding output.
      d_ff: Width of the dense layer in each block's feed-forward part.
      n_layers: How many decoder blocks to stack.
      n_heads: Number of attention heads.
      max_len: Longest sequence supported by the positional encoding.
      dropout: Dropout probability used in the embedding and the blocks.
      dropout_shared_axes: Axes along which a dropout mask is shared, e.g.
        `(0, 1)` to share over batch and sequence axes.
      mode: `'predict'` selects fast inference, `'train'` enables dropout;
        any other value passes values through unaltered.
      ff_activation: Activation layer class used at the end of each block's
        feed-forward part.

    Returns:
      A `tl.Serial` layer mapping a tensor of tokens to activations over the
      vocabulary.
    """
    # tokens --> embedded, position-aware vectors.
    embed_with_positions = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    # Stack contents after each layer are noted on the right.
    return tl.Serial(              # tokens (or chunked tuple of tokens)
        tl.ShiftRight(mode=mode),  # toks
        embed_with_positions,      # vecs
        blocks,                    # vecs
        tl.LayerNorm(),            # vecs
        tl.Dense(vocab_size),      # vecs
    )
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              n_layers_forget=0,
              mode='train'):
    """Builds a reversible encoder-decoder Transformer (Reformer2).

    The model takes a (source tokens, target tokens) input pair. Encoder
    output and embedded decoder input are concatenated and run through one
    stack of reversible decoder blocks; the encoder part is stripped again
    before the output projection, which ends in `LogSoftmax`.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        source and target share a vocabulary and the embedding layers.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      d_attention_key: int (optional): depth of the key vector per attention
        head; defaults to `d_model // n_heads`.
      d_attention_value: int (optional): depth of the value vector per
        attention head; defaults to `d_model // n_heads`.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate.
      max_len: int: maximum symbol length for positional encoding.
      encoder_attention_type: class: attention class for encoder blocks.
      encoder_decoder_attention_type: class, or tuple/list of classes that
        are cycled over the decoder layers; in the latter case its length
        must divide `n_decoder_layers`.
      axial_pos_shape: positional-encoding selector, forwarded unchanged to
        the file's `PositionalEncoding` helper.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis; forwarded to the `PositionalEncoding` helper.
      ff_activation: the non-linearity in the feed-forward layer.
      ff_use_sru: int; if > 0, use this many SRU layers instead of
        feed-forward.
      ff_chunk_size: int; if > 0, chunk feed-forward into chunks of this
        size.
      ff_dropout: float (optional): separate dropout rate at the
        feed-forward nonlinearity (called relu_dropout in T2T).
      ff_sparsity: int; if > 0, use a sparse feed-forward block with this
        sparsity.
      n_layers_forget: how often to have a forgetting block between layers.
      mode: str: 'train', 'eval' or 'predict'.

    Returns:
      A `tl.Serial` layer mapping a (source, target) token pair to
      log-softmax activations over the output vocabulary.

    Raises:
      ValueError: if a head depth must be defaulted but `n_heads` does not
        divide `d_model`.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts
    # bool-dtype masks on the stack. This causes jax to error, even though
    # the so-called "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if fastmath.is_backend(fastmath.Backend.JAX):
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    # Fill in default per-head key/value depths. The divisibility check is
    # only needed (and only performed) when at least one default is used.
    if d_attention_key is None or d_attention_value is None:
        if d_model % n_heads != 0:
            raise ValueError(
                f'n_heads ({n_heads}) must divide d_model ({d_model})')
        default_head_depth = d_model // n_heads
        if d_attention_key is None:
            d_attention_key = default_head_depth
        if d_attention_value is None:
            d_attention_value = default_head_depth

    def _TokenEmbedding(n_tokens):  # tokens --> vectors
        return [
            tl.Embedding(n_tokens, d_model),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        ]

    def _Positions(pos_mode):
        return PositionalEncoding(pos_mode, dropout, max_len,
                                  axial_pos_shape, d_axial_pos_embs)

    source_embedding = _TokenEmbedding(input_vocab_size)
    if output_vocab_size is None:
        # Shared vocabulary: reuse the same layer instances (shared weights).
        target_embedding = source_embedding
    else:
        target_embedding = _TokenEmbedding(output_vocab_size)

    # Mode 'predict' means that the decoder should be run one token at a
    # time. The encoder only ever runs over full sequences, which is why it
    # is switched to 'eval' mode instead.
    if mode == 'predict':
        encoder_mode = 'eval'
    else:
        encoder_mode = mode
    in_encoder = source_embedding + [_Positions(encoder_mode)]
    out_encoder = target_embedding + [_Positions(mode)]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                         dropout=dropout,
                         ff_activation=ff_activation,
                         ff_dropout=ff_dropout,
                         ff_use_sru=ff_use_sru,
                         ff_chunk_size=ff_chunk_size,
                         ff_sparsity=ff_sparsity,
                         mode=mode))

    encoder = tl.Serial([  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),          # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # Normalize the decoder attention spec to a list cycled over the layers.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    n_attention_kinds = len(encoder_decoder_attention_type)

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        DecoderBlock(
            d_model, d_ff, d_attention_key, d_attention_value, n_heads,
            attention_type=encoder_decoder_attention_type[
                layer_idx % n_attention_kinds],
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            mode=mode)
        for layer_idx in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Stack contents after each layer are noted on the right.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 0, 1, 1]),     # tok_e tok_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            in_encoder, [], [],         # vec_e tok_e tok_e vec_d tok_d
            [tl.ShiftRight(mode=mode), out_encoder]),

        tl.Parallel([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
        ]),                             # vec_e mask_e tok_e vec_d tok_d

        # Encode.
        encoder,                        # vec_e mask_e tok_e vec_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),        # vec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given their masks.
        tl.Select([1, 0]),              # vec_e vec_d mask_e tok_e tok_d
        _ConcatWithPadding(),           # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),                       # vec_ed1 vec_ed2 tok_e tok_d
        _ReversibleSerialForget(
            decoder_blocks, d_model,
            n_layers_forget),           # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),                 # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),        # vec_ed tok_e tok_d tok_d
        _StripFromConcatenateWithPadding(),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),    # vec_d tok_d
        tl.LogSoftmax(),                # vec_d tok_d
    )
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
    """Builds a reversible encoder-decoder model without enc-dec attention.

    The model takes a (source tokens, target tokens) input pair. Instead of
    dedicated encoder-decoder attention, the encoder output and the embedded
    (right-shifted) decoder input are concatenated and run through a single
    stack of reversible decoder blocks; the encoder part is stripped again
    before the output projection, which ends in `LogSoftmax`.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source vocab size is used; embeddings are still separate layers.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      d_attention_key: int: depth of key vector for each attention head.
      d_attention_value: int: depth of value vector for each attention head.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate.
      max_len: int: maximum symbol length for positional encoding.
      encoder_attention_type: class: attention class for encoder blocks.
      encoder_decoder_attention_type: class, or tuple/list of classes that
        are cycled over the decoder layers; in the latter case its length
        must divide `n_decoder_layers`.
      axial_pos_shape: tuple of ints: input shape for axial position
        encoding. If empty, regular positional encoding is used.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis; required when `axial_pos_shape` is set.
      ff_activation: the non-linearity in the feed-forward layer.
      ff_use_sru: int; if > 0, use this many SRU layers instead of
        feed-forward (passed to decoder blocks).
      ff_chunk_size: int; if > 0, chunk feed-forward into chunks of this
        size (passed to decoder blocks).
      ff_dropout: float (optional): separate dropout rate at the
        feed-forward nonlinearity (passed to encoder blocks).
      mode: str: 'train', 'eval' or 'predict'.

    Returns:
      A `tl.Serial` layer mapping a (source, target) token pair to
      log-softmax activations over the output vocabulary.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts
    # bool-dtype masks on the stack. This causes jax to error, even though
    # the so-called "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if math.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def _EmbedWithPositions(n_tokens, layer_mode):  # tokens --> vectors
        if axial_pos_shape:
            assert d_axial_pos_embs is not None
            broadcast_dims = tuple(range(1, len(axial_pos_shape) + 1))
            positions = tl.AxialPositionalEncoding(
                shape=axial_pos_shape,
                d_embs=d_axial_pos_embs,
                dropout_broadcast_dims=broadcast_dims,
                dropout=dropout,
                mode=layer_mode)
        else:
            positions = tl.PositionalEncoding(
                max_len=max_len, dropout=dropout, mode=layer_mode)

        return [
            tl.Embedding(d_model, n_tokens),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=layer_mode),
            positions,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if
    # output_vocab_size is None. This isn't supported here because (a) Trax
    # shares weights by sharing layer instances, but we need two separate
    # instances to have mode == 'eval' for the encoder but mode == 'predict'
    # for the decoder; and (b) tl.Cache does not work if its sublayers
    # participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a
    # time. The encoder only ever runs over full sequences, which is why it
    # is switched to 'eval' mode instead.
    if mode == 'predict':
        encoder_mode = 'eval'
    else:
        encoder_mode = mode
    in_encoder = _EmbedWithPositions(input_vocab_size, encoder_mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = _EmbedWithPositions(output_vocab_size, mode)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                         dropout, ff_activation, ff_dropout, mode))

    encoder = tl.Serial([  # tok_e mask_e tok_e tok_d tok_d
        in_encoder,        # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),          # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # Normalize the decoder attention spec to a list cycled over the layers.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    n_attention_kinds = len(encoder_decoder_attention_type)

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        DecoderBlock(
            d_model, d_ff, d_attention_key, d_attention_value, n_heads,
            attention_type=encoder_decoder_attention_type[
                layer_idx % n_attention_kinds],
            dropout=dropout,
            ff_activation=ff_activation,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            mode=mode)
        for layer_idx in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Stack contents after each layer are noted on the right.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),        # tok_e tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1),
        ]),                             # tok_e mask_e tok_e tok_d tok_d

        # Encode.
        encoder,                        # vec_e mask_e tok_e tok_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),        # tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),       # stok_d vec_e mask_e tok_e tok_d
        tl.Branch(
            [], _MaskOfRightShiftedArray()
        ),                        # stok_d mask_d vec_e mask_e tok_e tok_d
        out_encoder,              # svec_d mask_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given their masks.
        tl.Select([2, 0, 3, 1]),  # vec_e svec_d mask_e mask_d tok_e tok_d
        _ConcatWithPadding(),           # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),                       # vec_ed1 vec_ed2 tok_e tok_d
        tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),                 # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),        # vec_ed tok_e tok_d tok_d
        _StripFromConcatenateWithPadding(),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),    # vec_d tok_d
        tl.LogSoftmax(),                # vec_d tok_d
    )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Builds a reversible encoder-decoder Transformer (Reformer).

    The model takes two token tensors; the first goes to the encoder side and
    the second to the decoder side. Both the encoder and the decoder run as
    reversible stacks, and the output projection ends in `LogSoftmax`. Only
    dot-product attention is supported here; for the attention types from
    the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source vocab size is used; embeddings are still separate layers.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate.
      max_len: int: maximum symbol length for positional encoding.
      ff_activation: the non-linearity in the feed-forward layer.
      ff_dropout: float (optional): separate dropout rate at the
        feed-forward nonlinearity (called relu_dropout in T2T).
      mode: str: 'train', 'eval' or 'predict'.

    Returns:
      A `tl.Serial` layer mapping an (encoder tokens, decoder tokens) pair to
      log-softmax activations over the output vocabulary.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts
    # bool-dtype masks on the stack. This causes jax to error, even though
    # the so-called "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if math.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def _EmbedWithPositions(n_tokens, layer_mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long
        # sequences.
        positions = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=layer_mode)
        return [
            tl.Embedding(d_model, n_tokens),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=layer_mode),
            positions,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if
    # output_vocab_size is None. This isn't supported here because (a) Trax
    # shares weights by sharing layer instances, but we need two separate
    # instances to have mode == 'eval' for the encoder but mode == 'predict'
    # for the decoder; and (b) tl.Cache does not work if its sublayers
    # participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a
    # time. The encoder only ever runs over full sequences, which is why it
    # is switched to 'eval' mode instead.
    if mode == 'predict':
        encoder_mode = 'eval'
    else:
        encoder_mode = mode
    in_encoder = _EmbedWithPositions(input_vocab_size, encoder_mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = _EmbedWithPositions(output_vocab_size, mode)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                         ff_activation, ff_dropout, mode))

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = []
    for _ in range(n_decoder_layers):
        encoder_decoder_blocks.append(
            EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                ff_activation, ff_dropout, mode))

    # Stack contents after each layer are noted on the right.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),           # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1),
        ]),                             # tok_e mask tok_d .....

        # Encode.
        encoder,                        # vec_e mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),           # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),       # tok_d vec_e mask .....
        out_encoder,                    # vec_d vec_e mask .....
        tl.Dup(),                       # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
        tl.LayerNorm(),                 # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),         # vec_d .....
        tl.Dense(output_vocab_size),    # vec_d .....
        tl.LogSoftmax(),                # vec_d .....
    )
def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           mode,
                           use_bfloat16=False,
                           ff_sparsity_type='1inN'):
    """Feed-forward block with all the options.

    Selects one of three feed-forward implementations (sparse `1inN`, block
    sparse, or the regular dense `_FeedForward`), surrounds it with
    `LayerNorm` and `Dropout`, and optionally adds chunking and an SRU
    residual branch.

    Args:
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        block.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      ff_activation: Type of activation function at the end of each block;
        must be an activation-type subclass of `Layer`.
      ff_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward.
      ff_sparsity: int, tuple/list or string; if not 0, use a sparse
        feed-forward block. For type `'1inN'`, an int `n` means
        `n_elements_in_block=n` and `d_lowrank=d_ff // n`; a 2-sequence is
        `(n_elements_in_block, d_lowrank)`; a 4-sequence additionally gives
        `(temperature, quant_prob)` (defaults 0.1 and 0.3). A string is
        split on whitespace and parsed into such a sequence. For type
        `'Block'` the value is passed as `num_experts`.
      mode: If `'train'`, each block will include dropout; else, it will pass
        all values through unaltered.
      use_bfloat16: whether to use bfloat16 for weights (default: False).
      ff_sparsity_type: string; if `ff_sparsity` is set, use `SparseFF` for
        `'1inN'` and `BlockSparseFF` for `'Block'`. Any other value falls
        back to the dense feed-forward layer.

    Returns:
      A one-element list holding the layers (or combined layer) that map
      vectors to vectors.
    """
    if ff_sparsity and ff_sparsity_type == '1inN':
        temperature, quant_prob = 0.1, 0.3
        if isinstance(ff_sparsity, str):
            # This is hacky but used to pass ff_sparsity in yaml sweep files.
            ff_sparsity = [(float(x) if '.' in x else int(x))
                           for x in ff_sparsity.split()]
        if isinstance(ff_sparsity, (list, tuple)):
            if len(ff_sparsity) == 2:
                n_elements_in_block, d_lowrank = ff_sparsity
            else:
                (n_elements_in_block, d_lowrank,
                 temperature, quant_prob) = ff_sparsity
        else:
            assert isinstance(ff_sparsity, int)
            n_elements_in_block, d_lowrank = ff_sparsity, d_ff // ff_sparsity
        ff = tl.SparseFF(
            d_ff,
            n_elements_in_block=n_elements_in_block,
            d_lowrank=d_lowrank,
            temperature=temperature,
            quant_prob=quant_prob,
            mode=mode)
    elif ff_sparsity and ff_sparsity_type == 'Block':
        # No trailing comma here: `ff` must be the layer itself, like in the
        # other two branches, not a 1-tuple wrapping it.
        ff = tl.BlockSparseFF(d_ff, num_experts=ff_sparsity, mode=mode)
    else:
        ff = _FeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,
                          use_bfloat16, mode)

    res = [
        tl.LayerNorm(),
        ff,
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]
    if ff_chunk_size > 0:
        res = tl.BatchLeadingAxes(tl.Chunk(tl.Serial(res), ff_chunk_size))
    if ff_use_sru:
        # SRU stack runs at a fixed width of 32 and is projected back to
        # d_model before the feed-forward layers inside the residual.
        sru = [tl.Dense(32)] + [tl.SRU(32) for _ in range(ff_use_sru)]
        res = tl.Residual(sru + [tl.Dense(d_model)], res)
    return [res]
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, n_attention_chunks, attention_type,
                 dropout, share_qk, mode):
  """Builds one reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: truthy to use a single projection for both queries and keys
    mode: str: 'train' or 'eval'

  Returns:
    A list of four layers: the reversible attention half-residual, a swap,
    the reversible feed-forward half-residual, and a second swap.
  """

  def _Heads(d_head):
    return tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_head)

  # Both variants start the same way: chunk, normalize, duplicate once.
  pre_attention = [
      Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      tl.LayerNorm(),
      tl.Dup(),
  ]
  if share_qk:
    # Two projections (key-sized, value-sized); the first is then duplicated.
    pre_attention += [
        tl.Parallel(_Heads(d_attention_key), _Heads(d_attention_value)),
        tl.Dup(),
    ]
  else:
    # Three independent projections: two key-sized and one value-sized.
    pre_attention += [
        tl.Dup(),
        tl.Parallel(
            _Heads(d_attention_key),
            _Heads(d_attention_key),
            _Heads(d_attention_value),
        ),
    ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [FeedForward(d_model, d_ff, dropout, mode=mode)]

  attention_half = ReversibleAttentionHalfResidual(
      pre_attention, attention, post_attention)
  feed_forward_half = ReversibleHalfResidual(feed_forward)
  return [
      attention_half,
      tl.ReversibleSwap(),
      feed_forward_half,
      tl.ReversibleSwap(),
  ]
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """BERT (default hparams are for bert-base-uncased).

  Args:
    d_model: int: depth of the embeddings and of each encoder layer output
    vocab_size: int: size of the word-embedding table
    max_len: int: maximum length for the positional encoding
    type_vocab_size: int: size of the token-type embedding table
    n_heads: int: number of attention heads; each head has depth
      `d_model // n_heads`
    d_ff: int: depth of the hidden feed-forward layer
    n_layers: int: number of encoder layers
    head: optional zero-argument callable returning a layer that is applied
      after the BERT body
    init_checkpoint: checkpoint handed to `PretrainedBERT`; it is only
      forwarded when `mode == 'train'` and replaced by None otherwise
    mode: str: mode passed to the positional encoding and attention layers

  Returns:
    A `PretrainedBERT` layer, wrapped in `tl.Serial` together with `head()`
    when a head is given.
  """
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(vocab_size, d_model)
  type_embeddings = tl.Embedding(type_vocab_size, d_model)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
  mask_from_tokens = [
      tl.PaddingMask(),
      tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1),
  ]
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(word_embeddings, type_embeddings, mask_from_tokens),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  def _EncoderLayer():
    """Returns the layers of one attention + feed-forward encoder block."""
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [tl.Dense(d_ff), tl.Gelu(), tl.Dense(d_model)]
    return [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

  encoder = []
  for _ in range(n_layers):
    encoder.extend(_EncoderLayer())
  encoder.append(tl.Select([0], n_in=2))  # Drop the mask

  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  # The checkpoint is only used when training.
  if mode != 'train':
    init_checkpoint = None

  bert = PretrainedBERT(embeddings + encoder + pooler,
                        init_checkpoint=init_checkpoint)
  if head is None:
    return bert
  return tl.Serial(bert, head())
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline); 0 means the
      input is not chunked and is handled as a single section
    n_attention_chunks: int: number of chunks for attention
    attention_type: class, or tuple/list of classes, of attention to use, such
      as DotProductAttention. A sequence is cycled over the layers and its
      length must divide n_layers.
    share_qk: bool, whether to share queries and keys. Sharing is also turned
      on for any layer whose attention class subclasses
      tl.LSHCausalAttention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  input_is_chunked = n_chunks != 0
  if input_is_chunked:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)
  else:
    concatenate_input_chunks = []
    n_chunks = 1

  # NOTE(review): `mode` is not forwarded to either positional encoding
  # below -- confirm whether that is intended.
  if axial_pos_shape:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout)
  else:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout)

  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # Normalize attention_type to a sequence that is cycled across layers.
  if not isinstance(attention_type, (tuple, list)):
    attention_type = [attention_type]
  else:
    assert n_layers % len(attention_type) == 0

  decoder_blocks = []
  n_attention_types = len(attention_type)
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % n_attention_types]
    layer_share_qk = (
        share_qk or issubclass(layer_attention_type, tl.LSHCausalAttention))
    decoder_blocks.append(
        DecoderBlock(
            d_model, d_ff, d_attention_key, d_attention_value, n_heads,
            n_attention_chunks,
            attention_type=layer_attention_type,
            dropout=dropout,
            share_qk=layer_share_qk,
            mode=mode))

  output_head = [
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ]
  reversible_stack = tl.ReversibleSerial(decoder_blocks + [
      SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
  ])

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      reversible_stack,
      Map(output_head, n_sections=n_chunks),
  )
def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           mode,
                           ff_sparsity_type='1inN'):
  """Feed-Forward block with all the options.

  Exactly one variant is built, checked in this order: an SRU stack, a
  `'1inN'` sparse block, a `'Block'` sparse block, or the default chunked
  dense feed-forward.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    ff_activation: Type of activation function at the end of each block; must
      be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_sparsity: int or tuple; if not 0, use a sparse feed-forward block. For
      `'1inN'`, an int `n` means `(n_elements_in_block, d_lowrank) =
      (n, d_ff // n)` and a tuple gives that pair directly.
    mode: If `'train'`, each block will include dropout; else, it will pass
      all values through unaltered.
    ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
      ff_sparsity_type=`'1inN'` and use BlockSparseFF if
      ff_sparsity_type=`'Block'`

  Returns:
    A list of layers which maps vectors to vectors.
  """
  # SRU layers replace the feed-forward block entirely.
  if ff_use_sru:
    return [tl.SRU(d_model) for _ in range(ff_use_sru)]

  if ff_sparsity and ff_sparsity_type == '1inN':
    if isinstance(ff_sparsity, tuple):
      n_elements_in_block, d_lowrank = ff_sparsity
    else:
      assert isinstance(ff_sparsity, int)
      n_elements_in_block = ff_sparsity
      d_lowrank = d_ff // ff_sparsity
    sparse_ff = tl.SparseFF(d_ff,
                            n_elements_in_block=n_elements_in_block,
                            d_lowrank=d_lowrank,
                            mode=mode)
    if ff_chunk_size >= 1:
      sparse_ff = tl.BatchLeadingAxes(
          tl.Chunk(tl.Serial(sparse_ff), ff_chunk_size))
    return [
        tl.LayerNorm(),
        sparse_ff,
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

  if ff_sparsity and ff_sparsity_type == 'Block':
    return [
        tl.LayerNorm(),
        tl.BlockSparseFF(d_ff, num_experts=ff_sparsity, mode=mode),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

  # Default: dense feed-forward, with chunking handled by the helper.
  return [
      ChunkedFeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,
                         ff_chunk_size, mode)
  ]
def FunnelTransformerEncoder(vocab_size,
                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             encoder_segment_lengths=(2, 2, 2),
                             n_heads=8,
                             max_len=2048,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             mode='train',
                             ff_activation=tl.Relu,
                             pool_layer=tl.AvgPool,
                             pool_size=(2,),
                             strides=(2,),
                             separate_cls=True):
  """Returns a Funnel Encoder.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark
      padding positions.

    - output: rank 2 tensor with one vector of `n_classes` activations per
      example (the output of the final `Dense(n_classes)`); shape is
      (batch_size, `n_classes`).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
      classification.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    encoder_segment_lengths: Non-empty tuple, where each element denotes the
      number of transformer encoder blocks preceding a funnel transformer
      block. There is no funnel block after the last sequence of encoder
      blocks, therefore the total number of blocks in the model is equal to
      `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
      block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
      funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
      If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
      must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
      neighboring windows along each axis. If specified, must be a tuple of
      the same length as `pool_size`. If None, then offsets of 1 along each
      window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
      embeddings of the first token (`cls` from BERT paper) and only final
      embedding of this token is used for categorization - the rest are
      discarded. If `False`, each token from the beginning is pooled and
      all embeddings are averaged and mapped to output categories like in
      original `TransformerEncoder` model.

  Returns:
    A Transformer model that maps strings (conveyed via token IDs) to
    probability-like activations over a range of output classes.
  """
  assert encoder_segment_lengths

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  encoder_blocks = []
  last_segment_idx = len(encoder_segment_lengths) - 1
  for segment_idx, segment_length in enumerate(encoder_segment_lengths):
    # Plain encoder blocks that make up this segment.
    encoder_blocks.extend(
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for _ in range(segment_length))
    # A funnel (pooling) block follows every segment except the last one.
    if segment_idx != last_segment_idx:
      encoder_blocks.append(
          _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                       mode, ff_activation, pool_layer, pool_size, strides,
                       separate_cls))

  if separate_cls:
    cls_pooling = SelectFirst()
  else:
    cls_pooling = tl.Mean(axis=1)

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Map to output categories.
      cls_pooling,                                # cls
      tl.Dense(n_classes),                        # cls
  )
def DecoderBlock(d_model,
                 d_ff,
                 n_heads,
                 dropout,
                 dropout_shared_axes,
                 mode,
                 ff_activation,
                 ff_dropout,
                 ff_chunk_size,
                 ff_use_sru,
                 ff_sparsity,
                 ff_sparsity_type,
                 attention_chunk_size,
                 attention_type,
                 n_attention_layers=1,
                 n_feedforward_layers=1):
  """Returns a list of layers that implements a Transformer decoder block.

  The input is an activation tensor. The block is `n_attention_layers`
  residual causal-attention layers followed by `n_feedforward_layers`
  residual feed-forward layers.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will pass
      all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
      be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
      ff_sparsity_type=`'1inN'` and use BlockSparseFF if
      ff_sparsity_type=`'Block'`
    attention_chunk_size: int, if > 0 run attention chunked at this size
    attention_type: The attention layer to use.
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).

  Returns:
    A list of layers that maps an activation tensor to an activation tensor.
  """

  def _CausalAttention():
    """Builds one causal, unmasked attention layer."""
    return ApplyAttentionLayer(
        attention_type, d_model, n_heads,
        d_model // n_heads, d_model // n_heads,
        causal=True, masked=False,
        attention_dropout=dropout, output_dropout=dropout,
        attention_chunk_size=attention_chunk_size, mode=mode)

  def _FeedForward():
    """Builds one residual feed-forward layer."""
    return tl.Residual(
        FeedForwardWithOptions(d_model, d_ff, dropout, dropout_shared_axes,
                               ff_activation, ff_dropout, ff_chunk_size,
                               ff_use_sru, ff_sparsity, mode,
                               ff_sparsity_type))

  # All attention layers are built first, then each is wrapped in a residual.
  causal_attentions = [_CausalAttention() for _ in range(n_attention_layers)]
  residual_attentions = []
  for attention in causal_attentions:
    residual_attentions.append(
        tl.Residual(
            tl.LayerNorm(),
            attention,
            tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,
                       mode=mode)))

  feed_forwards = [_FeedForward() for _ in range(n_feedforward_layers)]

  return residual_attentions + feed_forwards
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
  """Returns a Full Funnel Transformer, that can be used for example for BERT.

  This model outputs token-level categorical distributions over all vocab:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark
      padding positions.

    - output: rank 3 tensor with `vocab_size` activations for each token (the
      output of the final `Dense(vocab_size)`); shape is
      (batch_size, sequence_length, vocab_size).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    encoder_segment_lengths: Non-empty tuple, where each element denotes the
      number of transformer encoder blocks preceding a funnel transformer
      block. There is no funnel block after the last sequence of encoder
      blocks, therefore the total number of blocks in the model is equal to
      `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
      block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
      funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
      If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
      must be a tuple of length :math:`n-2`. It is also used as the pooling
      strides, and `pool_size[0]` sets the upsampling factor.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
      embeddings of the first token (`cls` from BERT paper) and only final
      embedding of this token is used for categorization - the rest are
      discarded. If `False`, each token from the beginning is pooled and
      all embeddings are averaged and mapped to output categories like in
      original `TransformerEncoder` model.

  Returns:
    A `tl.Serial` model mapping token IDs to per-token activations over the
    vocabulary.
  """
  assert encoder_segment_lengths

  def _Block():
    """Builds one plain encoder block; also used for the decoder."""
    return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation)

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  first_segment_length = encoder_segment_lengths[0]
  encoder_blocks_before_first_pooling = [
      _Block() for _ in range(first_segment_length)]

  # Every later segment is preceded by a funnel (pooling) block.
  encoder_blocks_from_first_pooling = []
  for segment_length in encoder_segment_lengths[1:]:
    encoder_blocks_from_first_pooling.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, pool_layer,
                     pool_size=pool_size, strides=pool_size,
                     separate_cls=separate_cls))
    encoder_blocks_from_first_pooling.extend(
        _Block() for _ in range(segment_length))

  decoder_blocks = [_Block() for _ in range(n_decoder_blocks)]

  # One pooling step per funnel block, each by a factor of pool_size[0].
  n_funnel_blocks = len(encoder_segment_lengths) - 1
  total_pool_size = pool_size[0] ** n_funnel_blocks

  # Assemble and return the model.
  return tl.Serial(                               # toks
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks_before_first_pooling,        # vecs masks
      tl.Select([0, 1, 0, 1]),
      # vecs masks residual = vecs old_masks
      encoder_blocks_from_first_pooling,          # vecs masks residual masks
      tl.Select([0, 2, 3]),                       # vecs residual masks
      tl.Parallel(
          # residual from first segment is taken before
          # normalization, so apply it now
          None, tl.LayerNorm(), None),            # vecs norm(residual) masks
      _Upsampler(total_pool_size, separate_cls),  # vecs masks
      decoder_blocks,
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 ff_dropout=0.1,
                 ff_chunk_size=0,
                 ff_use_sru=0,
                 ff_sparsity=0,
                 ff_sparsity_type='1inN',
                 attention_chunk_size=0,
                 encoder_attention_type=tl.Attention,
                 n_encoder_attention_layers=1,
                 decoder_attention_type=tl.CausalAttention,
                 n_decoder_attention_layers=2,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
  """Returns a Transformer model.

  This model expects an input pair: encoder-side tokens, decoder-side tokens.
  The encoder output is concatenated with the embedded decoder input and the
  combined sequence is run through the decoder blocks.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict'; in 'predict' mode the encoder is
      wrapped in `tl.Cache`
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
      ff_sparsity_type=`'1inN'` and use BlockSparseFF if
      ff_sparsity_type=`'Block'`
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    n_encoder_attention_layers: int, within each encoder block, how many
      attention layers to have.
    decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    n_decoder_attention_layers: int, within each decoder block, how many
      attention layers to have.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.

  Returns:
    A Transformer model as a layer that maps from an (encoder tokens, decoder
    tokens) pair to activations over a vocab set.
  """
  embedding_layers = ct.EmbeddingAndPositionalEncodings(
      input_vocab_size, d_model, mode, dropout, dropout_shared_axes, max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs)
  in_encoder, out_encoder, output_vocab_size = embedding_layers

  # Leading positional arguments shared by the encoder and decoder blocks.
  shared_block_args = (
      d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
      ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
      ff_sparsity_type, attention_chunk_size)

  encoder_blocks = [
      ct.EncoderBlock(*shared_block_args, encoder_attention_type,
                      n_encoder_attention_layers)
      for _ in range(n_encoder_layers)
  ]
  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = [
      ct.DecoderBlock(*shared_block_args, decoder_attention_type,
                      n_decoder_attention_layers)
      for _ in range(n_decoder_layers)
  ]

  encoder_mask = tl.Fn('EncoderMask', lambda x: x != 0, n_out=1)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),                 # tok_e tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e mask_e tok_e tok_d tok_d
      encoder,                                 # vec_e mask_e tok_e tok_d tok_d

      # Simple encoder mask, doesn't contain extra dims.
      tl.Select([2, 0, 2], n_in=3),            # tok_e vec_e tok_e tok_d tok_d
      encoder_mask,                            # mask_e vec_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 1, 0, 2]),                 # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
      out_encoder,                             # svec_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder.
      tl.Select([1, 0]),                       # vec_e svec_d mask_e tok_e tok_d
      ConcatWithPadding(mode=mode),            # vec_ed tok_e tok_d

      # Decoder blocks with causal attention
      decoder_blocks,                          # vec_ed tok_e tok_d
      tl.LayerNorm(),                          # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                 # vec_ed tok_e tok_d tok_d
      StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),             # vec_d tok_d
  )
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, attention_chunk_size=0, mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train', 'eval' or 'predict'; 'predict' is treated as 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  encoder_mode = 'eval' if mode == 'predict' else mode
  d_head = d_model // n_heads

  attention = configurable_transformer.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_head,
      d_v=d_head,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=encoder_mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  # Dropout masks in the feed-forward block are shared along axis -2.
  feed_forward = configurable_transformer.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, encoder_mode)
  feed_forward_half_residual = tl.ReversibleHalfResidual(feed_forward)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      feed_forward_half_residual,
      tl.ReversibleSwap(),
  ]
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head;
      defaults to `d_model // n_heads`
    d_attention_value: int: depth of value vector for each attention head;
      defaults to `d_model // n_heads`
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class, or tuple/list of classes, of
      attention to use, such as SelfAttention. A sequence is cycled over the
      decoder layers and its length must divide n_decoder_layers.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to
      be used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train', 'eval' or 'predict'; in 'predict' mode the encoder is
      wrapped in `tl.Cache`

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.

  Raises:
    ValueError: if a head depth has to be defaulted and n_heads does not
      divide d_model.
  """
  # Set default dimensions for attention head key and value sizes.
  if d_attention_key is None or d_attention_value is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    default_d_head = d_model // n_heads
    if d_attention_key is None:
      d_attention_key = default_d_head
    if d_attention_value is None:
      d_attention_value = default_d_head

  # Vector embeddings.
  embedding_layers = ct.EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      [-2],  # dropout_shared_axes
      max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs,
      use_bfloat16=use_bfloat16)
  in_encoder, out_encoder, output_vocab_size = embedding_layers

  # Keyword arguments that encoder and decoder blocks have in common.
  shared_block_kwargs = dict(
      dropout=dropout,
      ff_activation=ff_activation,
      ff_dropout=ff_dropout,
      ff_use_sru=ff_use_sru,
      ff_chunk_size=ff_chunk_size,
      ff_sparsity=ff_sparsity,
      attention_chunk_size=attention_chunk_size,
      use_bfloat16=use_bfloat16,
      mode=mode)

  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                   **shared_block_kwargs)
      for _ in range(n_encoder_layers)
  ]

  encoder = tl.Serial([                # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.Dense(d_model, use_bfloat16=use_bfloat16),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # Normalize the decoder attention type to a sequence cycled across layers.
  if not isinstance(encoder_decoder_attention_type, (tuple, list)):
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  else:
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0

  decoder_blocks = []
  n_attention_types = len(encoder_decoder_attention_type)
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % n_attention_types]
    decoder_blocks.append(
        DecoderBlock(
            d_model, d_ff, d_attention_key, d_attention_value, n_heads,
            attention_type=layer_attention_type,
            n_attention_layers=n_decoder_attention_layers,
            **shared_block_kwargs))

  dense_loss_layer = tl.SparseDenseWithOptions(
      output_vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      use_bfloat16=use_bfloat16,
      mode=mode)

  squeezed_padding_mask = [
      tl.PaddingMask(),
      tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),          # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],      # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),

      tl.Parallel([], squeezed_padding_mask),
      #                                    # vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                             # vec_e mask_e tok_e vec_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),             # vec_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given encoder mask.
      tl.Select([1, 0]),                   # vec_e vec_d mask_e tok_e tok_d
      t2.ConcatWithPadding(mode=mode),     # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                            # vec_ed1 vec_ed2 tok_e tok_d
      _ReversibleSerialForget(decoder_blocks, d_model,
                              n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                      # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),             # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      dense_loss_layer,                    # vec_d tok_d
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Builds a reversible Transformer encoder-decoder (Reformer) model.

  The assembled layer takes a pair of inputs in the order
  (encoder-side tokens, decoder-side tokens).

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source vocab size is used for the target as well.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder-side
      embedding is built in 'eval' mode and the whole encoder is wrapped in
      `tl.Cache`.

  Returns:
    A `tl.Serial` layer ending in a dense projection to `output_vocab_size`
    followed by `tl.LogSoftmax`; a copy of the decoder-side tokens is kept
    below it on the stack for use in the loss.
  """

  def _token_embedder(vocab_size, embed_mode):
    """Returns the layers mapping token ids to position-aware vectors."""
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=embed_mode)
    token_layer = tl.Embedding(vocab_size, d_model)
    dropout_layer = tl.Dropout(rate=dropout, shared_axes=[-2], mode=embed_mode)
    return [token_layer, dropout_layer, position_layer]

  def _average_halves(x, y):
    # Merges the two streams of the reversible stack into one tensor.
    return (x + y) / 2.0

  def _squeeze_mask(x):
    # Drops axes 1 and 2 of the padding mask.
    return jnp.squeeze(x, (1, 2))

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why its embedding
  # is built in 'eval' mode instead.
  encoder_embed_mode = 'eval' if mode == 'predict' else mode
  in_encoder = _token_embedder(input_vocab_size, encoder_embed_mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = _token_embedder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                   ff_activation, ff_dropout, mode=mode)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder_layers = [
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', _average_halves),
      tl.LayerNorm(),
  ]
  encoder = tl.Serial(encoder_layers)
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                          ff_dropout, mode)
      for _ in range(n_decoder_layers)
  ]

  # Builds the (squeezed) padding mask next to the untouched encoder tokens.
  add_padding_mask = tl.Branch([], [
      tl.PaddingMask(),
      tl.Fn('Squeeze', _squeeze_mask, n_out=1),
  ])

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      add_padding_mask,                     # tok_e mask  tok_d .....

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg', _average_halves),      # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
def TransformerNoEncDecAttention(input_vocab_size,
                                 output_vocab_size=None,
                                 d_model=512,
                                 d_ff=2048,
                                 n_encoder_layers=6,
                                 n_decoder_layers=6,
                                 n_heads=8,
                                 dropout=0.1,
                                 dropout_shared_axes=None,
                                 max_len=2048,
                                 mode='train',
                                 ff_activation=tl.Relu):
  """Returns a Transformer model without encoder-decoder attention.

  Instead of attending from decoder to encoder, the encoder output and the
  embedded (right-shifted) decoder tokens are concatenated and run through
  decoder blocks; the encoder part is stripped off again before the output
  projection.

  The assembled layer takes a pair of inputs in the order
  (encoder-side tokens, decoder-side tokens).

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the very same
      embedding layers are used on both sides.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'; with 'predict' the encoder is additionally
      wrapped in `tl.Cache`.
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A `tl.Serial` layer ending in a dense projection to the output vocab
    followed by `tl.LogSoftmax`; a copy of the decoder-side tokens is kept
    below it on the stack for use in the loss.
  """

  def _embedding_stack(vocab_size):
    """Returns the layers mapping token ids to position-aware vectors."""
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  in_encoder = _embedding_stack(input_vocab_size)
  if output_vocab_size is None:
    # Shared vocab: reuse the same layer objects on the decoder side.
    out_encoder = in_encoder
    output_vocab_size = input_vocab_size
  else:
    out_encoder = _embedding_stack(output_vocab_size)

  # pylint: disable=protected-access
  encoder_blocks = [
      transformer._EncoderBlock(d_model, d_ff, n_heads, dropout,
                                dropout_shared_axes, mode, ff_activation)
      for _ in range(n_encoder_layers)
  ]

  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = [
      transformer._DecoderBlock(d_model, d_ff, n_heads, dropout,
                                dropout_shared_axes, mode, ff_activation)
      for _ in range(n_decoder_layers)
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Simple encoder mask, doesn't contain extra dims.
      tl.Select([2, 0, 2], n_in=3),     # tok_e vec_e tok_e tok_d tok_d
      transformer._MaskOfRightShiftedArray(
          n_shifts=0),                  # mask_e vec_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 1, 0, 2]),          # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          transformer._MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      transformer._ConcatWithPadding(),  # vec_ed tok_e tok_d

      # Decoder blocks with causal attention
      decoder_blocks,                   # vec_ed tok_e tok_d
      tl.LayerNorm(),                   # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),          # vec_ed tok_e tok_d tok_d
      transformer._StripFromConcatenateWithPadding(),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),      # vec_d tok_d
      tl.LogSoftmax(),                  # vec_d tok_d
  )
  # pylint: enable=protected-access
def ReZeroTransformer(input_vocab_size,
                      output_vocab_size=None,
                      d_model=512,
                      d_ff=2048,
                      n_encoder_layers=6,
                      n_decoder_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      max_len=2048,
                      mode='train',
                      ff_activation=tl.Relu):
  """Returns a ReZero transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the same
      embedding and dropout layers are used on both sides.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'; with 'predict' the encoder-side positional
      encoding is built in 'eval' mode and the encoder is wrapped in
      `tl.Cache`.
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A ReZero transformer model as a layer that maps from a source, target pair
    to activations over a vocab set (the last sublayer is a dense projection
    to the output vocab).
  """

  def _token_embedding(vocab_size):
    """Returns the layers mapping token ids to vectors (no positions yet)."""
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
    ]

  shared_vocab = output_vocab_size is None
  in_embedder = _token_embedding(input_vocab_size)
  if shared_vocab:
    out_embedder = in_embedder
  else:
    out_embedder = _token_embedding(output_vocab_size)

  # Positional encoding are not shared between encoder and decoder.
  # Since encoder doesn't run stepwise, we do not use predict mode there.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = [
      *in_embedder,
      tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
  ]
  out_encoder = [
      *out_embedder,
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]
  if shared_vocab:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                    mode, ff_activation)
      for _ in range(n_encoder_layers)
  ]

  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                           dropout_shared_axes, mode, ff_activation)
      for _ in range(n_decoder_layers)
  ]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),         # tok_e masks ..... .....
      encoder,                                 # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),                # tok_d ..... ..... .....
      out_encoder,                             # vec_d ..... ..... .....
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      encoder_decoder_blocks,                  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),                  # vec_d tok_d
      tl.Dense(output_vocab_size),             # vec_d .....
  )
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
                 attention_chunk_size=0, center_layernorm=True,
                 use_bfloat16=False, use_two_swaps_per_block=True,
                 mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads; the per-head query/key and value
      depths are both set to `d_model // n_heads`.
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out); used for both the
      attention dropout and the attention output dropout, and passed on to
      the feed-forward layer.
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder
      block, otherwise use only one swap.
    mode: str: 'train' or 'eval'; 'predict' is accepted and treated as 'eval'.

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why its layers
  # are built in 'eval' mode instead.
  layer_mode = 'eval' if mode == 'predict' else mode

  d_head = d_model // n_heads
  attention = ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_head,
      d_v=d_head,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=layer_mode)

  # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
  # If we're using standard attention, we need to pass reshaped mask and not
  # return the mask to be compatible with the EfficientAttention API.
  if attention.n_out == 2:

    def _expand_mask(mask):
      # Inserts two singleton axes between the mask's two dimensions.
      return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))

    def _with_expanded_mask(x, y):
      return (x, _expand_mask(y))

    attention = tl.Serial(
        tl.Fn('ReshapeMask', _with_expanded_mask, n_out=2),
        attention,
        tl.Select([0], n_in=2),
    )

  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(center=center_layernorm),
      attention_layer=attention,
      name='ReversibleHalfResidualEncoderAttn',
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
      layer_mode, use_bfloat16)
  feed_forward_half_residual = tl.ReversibleHalfResidual(
      feed_forward, name='ReversibleHalfResidualEncoderFF')

  block_layers = [
      attention_half_residual,
      tl.ReversibleSwap(),
      feed_forward_half_residual,
  ]
  if use_two_swaps_per_block:
    # The optional trailing swap is the second swap of the block.
    block_layers.append(tl.ReversibleSwap())
  return block_layers
# ## Layers can have Weights
# Some layer types include mutable weights and biases that are used in
# computation and training. Layers of this type require initialization before
# use.
#
# For example, the `LayerNorm` layer calculates normalized data that is also
# scaled by weights and biases. During initialization you pass the data shape
# and data type of the inputs, so the layer can initialize compatible arrays
# of weights and biases.

# In[8]:


# Uncomment any of them to see information regarding the function
# help(tl.LayerNorm)
# help(shapes.signature)


# In[9]:


# Layer initialization
norm = tl.LayerNorm()
# You first must know what the input data will look like
x = np.array([0, 1, 2, 3], dtype="float")

# The layer is initialized from the input's signature (shape and type), not
# from the usual shape tuple, so build the trax ShapeDtype once and reuse it.
x_signature = shapes.signature(x)
norm.init(x_signature)

# Compare the plain shape tuple with the trax signature of the same data
print("Normal shape:", x.shape, "Data Type:", type(x.shape))
print("Shapes Trax:", x_signature, "Data Type:", type(x_signature))

# Inspect properties
print("-- Properties --")
print("name :", norm.name)