def PositionalEncoding(mode, dropout=None, max_len=None,
                       axial_pos_shape=None, d_axial_pos_embs=None):
    """Builds the positional-encoding layer selected by `axial_pos_shape`.

    Args:
      mode: Forwarded to every layer here that takes a `mode` argument.
      dropout: Dropout rate forwarded to the standard and axial encodings.
      max_len: Maximum length forwarded to the standard encoding.
      axial_pos_shape: Falsy selects `tl.PositionalEncoding`. The strings
        'fixed-base', 'infinite', 'infinite-affine' and 'time-bin' select the
        correspondingly named encoding. Any other value is passed as `shape`
        to `tl.AxialPositionalEncoding`.
      d_axial_pos_embs: Passed as `d_embs` to the axial encoding; must not be
        None when the axial branch is taken.

    Returns:
      The constructed positional-encoding layer.
    """
    if not axial_pos_shape:
        return tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=mode)

    # TODO(lukaszkaiser): remove this HACK -- the string values below overload
    # `axial_pos_shape` to pick non-axial encodings.
    if axial_pos_shape == 'fixed-base':
        return tl.FixedBasePositionalEncoding(mode=mode)
    if axial_pos_shape == 'infinite':
        return tl.InfinitePositionalEncoding(affine=False)
    if axial_pos_shape == 'infinite-affine':
        return tl.InfinitePositionalEncoding()
    if axial_pos_shape == 'time-bin':
        return tl.TimeBinPositionalEncoding()

    # Everything else is treated as a real axial shape.
    assert d_axial_pos_embs is not None
    n_axes = len(axial_pos_shape)
    return tl.AxialPositionalEncoding(
        shape=axial_pos_shape,
        d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
        dropout=dropout,
        mode=mode)
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Assembles a Transformer language model out of `DecoderBlock`s.

    Args:
      vocab_size: Vocabulary size; used for the embedding and the final Dense.
      d_model: Embedding depth.
      d_ff: Forwarded to each `DecoderBlock`.
      n_layers: Number of `DecoderBlock`s to stack.
      n_heads: Forwarded to each `DecoderBlock`.
      dropout: Dropout rate for the embedding dropout and each block.
      max_len: Maximum length for the positional encoding.
      mode: Forwarded to the mode-aware layers and blocks.
      ff_activation: Forwarded to each `DecoderBlock`.

    Returns:
      A `tl.Serial` layer: shift-right, embedding stack, decoder blocks,
      layer norm, Dense(vocab_size) and log-softmax.
    """
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            DecoderBlock(d_model, d_ff, n_heads, dropout, mode, ff_activation))

    # Chain the pieces so they execute one after another.
    pipeline = [
        tl.ShiftRight(mode=mode),
        embedding_stack,
        blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*pipeline)
def PositionalEncoder(vocab_size, d_model, dropout, max_len, mode):
    """Returns the embedding front-end as a list of three layers.

    The layers, in order, are a token embedding, a dropout layer and a
    positional encoding.

    Args:
      vocab_size (int): vocab size.
      d_model (int): depth of embedding.
      dropout (float): dropout rate (how much to drop out).
      max_len (int): maximum symbol length for positional encoding.
      mode (str): 'train' or 'eval'.

    Returns:
      list: `[tl.Embedding, tl.Dropout, tl.PositionalEncoding]`.
    """
    # Embedding table of dimension (vocab_size, d_model).
    token_embedding = tl.Embedding(vocab_size, d_model)
    # Dropout with the requested rate and mode.
    embedding_dropout = tl.Dropout(rate=dropout, mode=mode)
    # Positional encoding with the requested maximum length and mode.
    position_signal = tl.PositionalEncoding(max_len=max_len, mode=mode)
    return [token_embedding, embedding_dropout, position_signal]
def TransformerEncoder(vocab_size=vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu,
                       EncoderBlock=EncoderBlock):
    """Builds a Transformer encoder classifier over a tensor of tokens.

    Args:
      vocab_size (int): vocab size. Defaults to the module-level `vocab_size`.
      n_classes (int): number of output classes. Defaults to 10.
      d_model (int): depth of embedding. Defaults to 512.
      d_ff (int): depth of feed-forward layer. Defaults to 2048.
      n_layers (int): number of encoder blocks. Defaults to 6.
      n_heads (int): number of attention heads. Defaults to 8.
      dropout (float): dropout rate. Defaults to 0.1.
      dropout_shared_axes: axes on which to share the dropout mask.
        Defaults to None.
      max_len (int): maximum symbol length for positional encoding.
        Defaults to 2048.
      mode (str): 'train' or 'eval'. Defaults to 'train'.
      ff_activation (function): non-linearity in the feed-forward layer.
        Defaults to tl.Relu.
      EncoderBlock (function): factory returning one encoder block.
        Defaults to the module-level `EncoderBlock`.

    Returns:
      trax.layers.combinators.Serial: a layer mapping a tensor of tokens to
      activations over `n_classes` output classes.
    """
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Repeat the encoder block once per layer.
    blocks = []
    for _ in range(n_layers):
        blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation))

    # One branch embeds the tokens, the other derives the padding mask.
    embed_and_mask = tl.Branch(embedding_stack, tl.PaddingMask())

    return tl.Serial(
        embed_and_mask,
        blocks,
        tl.Select([0], n_in=2),
        tl.LayerNorm(),
        tl.Mean(axis=1),
        tl.Dense(n_classes),
        tl.LogSoftmax(),
    )
def SkippingTransformerLM(vocab_size,
                          d_model=512,
                          d_ff=2048,
                          n_layers=6,
                          n_heads=8,
                          d_attention_key=None,
                          d_attention_value=None,
                          attention_type=tl.DotProductCausalAttention,
                          dropout=0.1,
                          share_qk=False,
                          max_len=2048,
                          mode='train',
                          ff_activation=tl.Relu):
    """Builds a Skipping Transformer language model over a tensor of tokens.

    Only decoder blocks are used; they are wrapped in `SkippingSerial`.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks
      n_heads: int: number of attention heads
      d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
      d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
      attention_type: subclass of tl.BaseCausalAttention: attention class
      dropout: float: dropout rate (how much to drop out)
      share_qk: bool, whether to share queries and keys in decoder attention
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train', 'eval' or 'predict'; predict is for fast inference
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a tensor of tokens to activations over the
      vocabulary.
    """
    token_embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    shift = tl.ShiftRight(mode=mode)

    blocks = []
    for layer_idx in range(n_layers):
        block = transformer._DecoderBlock(  # pylint: disable=protected-access
            d_model, d_ff, n_heads, d_attention_key, d_attention_value,
            attention_type, dropout, share_qk, layer_idx, mode, ff_activation)
        blocks.append(block)
    skipping_stack = SkippingSerial(blocks, mode=mode)

    return tl.Serial(
        shift,
        token_embedder,
        skipping_stack,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def TransformerDecoder(vocab_size=None,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       d_attention_key=None,
                       d_attention_value=None,
                       attention_type=tl.DotProductCausalAttention,
                       dropout=0.1,
                       share_qk=False,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Builds a Transformer decoder for discrete or continuous input.

    With a `vocab_size` the input goes through an embedding; with
    `vocab_size=None` it goes through `tl.Dense(d_model)` instead. No
    shift-right layer is added, so the output for timestep t is based on
    inputs up to and including timestep t.

    Args:
      vocab_size: int or None: vocab size for discrete input, None otherwise.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks
      n_heads: int: number of attention heads
      d_attention_key: int: depth of key vector for each attention head
        (default is d_model // n_heads)
      d_attention_value: int: depth of value vector for each attention head
        (default is d_model // n_heads)
      attention_type: subclass of tl.BaseCausalAttention: attention class
      dropout: float: dropout rate (how much to drop out)
      share_qk: bool, whether to share queries and keys in decoder attention
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a continuous or discrete tensor to a
      continuous tensor.
    """
    if vocab_size is not None:
        input_layer = tl.Embedding(d_model, vocab_size)
    else:
        input_layer = tl.Dense(d_model)

    embedding_stack = [
        input_layer,
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    blocks = []
    for layer_idx in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, d_attention_key,
                          d_attention_value, attention_type, dropout,
                          share_qk, layer_idx, mode, ff_activation))

    # toks -> vecs -> vecs -> vecs
    return tl.Serial(embedding_stack, blocks, tl.LayerNorm())
def PositionalEncoder(mode,
                      dropout=None,
                      max_len=None,
                      pos_type=None,
                      pos_axial_shape=None,
                      pos_d_axial_embs=None,
                      pos_start_from_zero_prob=1.0,
                      pos_max_offset_to_add=0,
                      use_bfloat16=False):
    """Builds the positional-encoding layer selected by `pos_type`.

    Args:
      mode: If `'predict'`, use fast inference. If `'train'`, dropout is
        active; otherwise values pass through unaltered.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the embedding block.
      max_len: Maximum symbol length for positional encoding.
      pos_type: string naming the positional embedding to use. Falsy selects
        `tl.PositionalEncoding`. Recognised names are 'sin-cos', 'fixed-base',
        'infinite', 'infinite-affine', 'time-bin' and 'no'. Any other value
        selects the axial encoding.
      pos_axial_shape: tuple of ints: input shape for the axial position
        encoding; only read in the axial branch.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model. Must not be None in the axial branch.
      pos_start_from_zero_prob: how often to start from 0 during training
        (1.0 means always start from position 0, less means randomize).
      pos_max_offset_to_add: maximum offset to add to positions during
        training when randomizing; this offset plus input length must still
        be less than max_len for all training examples.
      use_bfloat16: If `True`, use bfloat16 weights instead of the default
        float32; this can save memory but may (rarely) lead to numerical
        issues.

    Returns:
      A layer that will do the positional encoding.
    """
    if not pos_type:
        return tl.PositionalEncoding(
            max_len=max_len,
            dropout=dropout,
            use_bfloat16=use_bfloat16,
            start_from_zero_prob=pos_start_from_zero_prob,
            max_offset_to_add=pos_max_offset_to_add,
            mode=mode)
    if pos_type == 'sin-cos':
        return tl.SinCosPositionalEncoding(mode=mode)
    if pos_type == 'fixed-base':
        return tl.FixedBasePositionalEncoding(mode=mode)
    if pos_type == 'infinite':
        return tl.InfinitePositionalEncoding(affine=False)
    if pos_type == 'infinite-affine':
        return tl.InfinitePositionalEncoding()
    if pos_type == 'time-bin':
        return tl.TimeBinPositionalEncoding()
    if pos_type == 'no':
        return tl.Serial()  # no positional encoding at all

    # TODO(lukaszkaiser): name this type and check for the correct name
    assert pos_d_axial_embs is not None
    n_axes = len(pos_axial_shape)
    return tl.AxialPositionalEncoding(
        shape=pos_axial_shape,
        d_embs=pos_d_axial_embs,
        dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
        dropout=dropout,
        mode=mode)
def test_predict_equals_eval(self):
    """Checks 'predict' and 'eval' modes give equal output on one input."""
    inputs = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]])
    self.assertEqual(inputs.shape, (1, 4, 2))

    def _initialized_layer(layer_mode):
        # Same construction and init for both modes; only `mode` differs.
        layer = tl.PositionalEncoding(max_len=8, d_feature=4, mode=layer_mode)
        layer.init(shapes.signature(inputs))
        return layer

    eval_layer = _initialized_layer('eval')
    eval_out = eval_layer(inputs)

    predict_layer = _initialized_layer('predict')
    # Share weights so the only remaining difference is the mode.
    predict_layer.weights = eval_layer.weights
    predict_out = predict_layer(inputs)

    self.assertTrue(np.array_equal(eval_out, predict_out))
def PositionalEncoder(vocab_size):
    """Returns the tokens --> vectors front-end as a list of layers.

    Reads `d_model`, `dropout`, `dropout_shared_axes`, `mode` and `max_len`
    from the enclosing scope.

    Args:
      vocab_size: Vocabulary size for the embedding layer.

    Returns:
      `[tl.Embedding, tl.Dropout, tl.PositionalEncoding]`.
    """
    embed_layer = tl.Embedding(vocab_size, d_model)
    dropout_layer = tl.Dropout(
        rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
    position_layer = tl.PositionalEncoding(max_len=max_len)
    return [embed_layer, dropout_layer, position_layer]
def _Encoder():
    """Builds the encoder stack; wraps it in `tl.Cache` in 'predict' mode.

    Reads `in_embedder`, `_Dropout`, `max_len`, `encoder_mode`, `_EncBlock`,
    `n_encoder_layers` and `mode` from the enclosing scope.
    """
    stages = [
        in_embedder,
        _Dropout(),
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
    ]
    stages.append([_EncBlock() for _ in range(n_encoder_layers)])
    stages.append(tl.LayerNorm())
    encoder = tl.Serial(*stages)

    if mode == 'predict':
        return tl.Cache(encoder)
    return encoder
def PositionalEncoder(vocab_size, mode):
    """Returns the tokens --> vectors front-end as a list of layers.

    Reads `max_len`, `dropout` and `d_model` from the enclosing scope.

    Args:
      vocab_size: Vocabulary size for the embedding layer.
      mode: Forwarded to the dropout and positional-encoding layers.

    Returns:
      `[tl.Embedding, tl.Dropout, tl.PositionalEncoding]`.
    """
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    embed_layer = tl.Embedding(vocab_size, d_model)
    dropout_layer = tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode)
    return [embed_layer, dropout_layer, position_layer]
def test_predict(self):
    """Checks chunked 'predict' calls reproduce the full-sequence output."""
    full_layer = tl.PositionalEncoding(max_len=8)
    x = np.array([[[2.0, 3.0], [1.0, 2.0], [0.0, 1.0], [3.0, 4.0]]])
    self.assertEqual(x.shape, (1, 4, 2))
    full_layer.init(shapes.signature(x))
    expected = full_layer(x)
    self.assertEqual(expected.shape, (1, 4, 2))

    predict_layer = tl.PositionalEncoding(max_len=8, mode='predict')
    predict_layer.init(shapes.signature(x[:, :1, :]))

    # Feed the sequence in consecutive chunks: the first token, the next two
    # tokens, then the final token.
    for start, stop in ((0, 1), (1, 3), (3, 4)):
        chunk_out = predict_layer(x[:, start:stop, :])
        self.assertEqual(chunk_out.shape, (1, stop - start, 2))
        self.assertTrue(
            np.array_equal(chunk_out, expected[:, start:stop, :]))
def PositionalEncoder(vocab_size, mode):
    """Returns the tokens --> vectors front-end as a list of layers.

    Reads `max_len`, `dropout` and `d_model` from the enclosing scope.

    Args:
      vocab_size: Vocabulary size for the embedding layer.
      mode: Forwarded to the dropout and positional-encoding layers.

    Returns:
      `[tl.Embedding, BroadcastedDropout, tl.PositionalEncoding]`.
    """
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    embed_layer = tl.Embedding(d_model, vocab_size)
    dropout_layer = BroadcastedDropout(rate=dropout, mode=mode)
    return [embed_layer, dropout_layer, position_layer]
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Builds a Transformer encoder classifier over a tensor of tokens.

    Args:
      vocab_size: int: vocab size
      n_classes: how many classes on output
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder blocks
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a tensor of tokens to activations over
      `n_classes` output classes.
    """
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _EncoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    # Encode: toks -> (vecs, masks).
    embed_and_mask = tl.Branch(embedding_stack, tl.PaddingMask())

    stages = [
        embed_and_mask,            # vecs masks
        blocks,                    # vecs masks
        tl.Select([0], n_in=2),    # vecs
        tl.LayerNorm(),            # vecs
        # Map to output categories.
        tl.Mean(axis=1),           # vecs
        tl.Dense(n_classes),       # vecs
        tl.LogSoftmax(),           # vecs
    ]
    return tl.Serial(*stages)
def PositionalEncoder(vocab_size, mode):
    """Returns the input --> vectors front-end as a list of layers.

    Uses an embedding when `vocab_size` is given and a bias-free Dense layer
    when it is None. Reads `d_model`, `dropout` and `max_len` from the
    enclosing scope.

    Args:
      vocab_size: Vocabulary size, or None to use `tl.Dense` instead.
      mode: Forwarded to the dropout and positional-encoding layers.

    Returns:
      A list of three layers: input projection, dropout, positional encoding.
    """
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    if vocab_size is not None:
        input_layer = tl.Embedding(vocab_size, d_model)
    else:
        input_layer = tl.Dense(d_model, use_bias=False)

    dropout_layer = tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode)
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [input_layer, dropout_layer, position_layer]
def PositionalEncoder(mode,
                      dropout=None,
                      max_len=None,
                      axial_pos_shape=None,
                      d_axial_pos_embs=None,
                      use_bfloat16=False):
    """Builds the positional-encoding layer selected by `axial_pos_shape`.

    Args:
      mode: If `'predict'`, use fast inference. If `'train'`, dropout is
        active; otherwise values pass through unaltered.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the embedding block.
      max_len: Maximum symbol length for positional encoding.
      axial_pos_shape: tuple of ints: input shape for the axial position
        encoding. Falsy disables axial encoding and selects
        `tl.PositionalEncoding`. The strings 'sin-cos', 'fixed-base',
        'infinite', 'infinite-affine' and 'time-bin' select the
        correspondingly named encoding instead.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      use_bfloat16: If `True`, use bfloat16 weights instead of the default
        float32; this can save memory but may (rarely) lead to numerical
        issues.

    Returns:
      A layer that will do the positional encoding.
    """
    if not axial_pos_shape:
        return tl.PositionalEncoding(
            max_len=max_len,
            dropout=dropout,
            mode=mode,
            use_bfloat16=use_bfloat16)

    # TODO(lukaszkaiser): remove this HACK -- the string values below overload
    # `axial_pos_shape` to pick non-axial encodings.
    if axial_pos_shape == 'sin-cos':
        return tl.SinCosPositionalEncoding(mode=mode)
    if axial_pos_shape == 'fixed-base':
        return tl.FixedBasePositionalEncoding(mode=mode)
    if axial_pos_shape == 'infinite':
        return tl.InfinitePositionalEncoding(affine=False)
    if axial_pos_shape == 'infinite-affine':
        return tl.InfinitePositionalEncoding()
    if axial_pos_shape == 'time-bin':
        return tl.TimeBinPositionalEncoding()

    # Everything else is treated as a real axial shape.
    assert d_axial_pos_embs is not None
    n_axes = len(axial_pos_shape)
    return tl.AxialPositionalEncoding(
        shape=axial_pos_shape,
        d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
        dropout=dropout,
        mode=mode)
def PositionalEncoder(vocab_size):
    """Returns the tokens --> vectors front-end as a list of layers.

    Reads `max_len`, `mode`, `d_model` and `dropout` from the enclosing scope.

    Args:
      vocab_size: Vocabulary size for the embedding layer.

    Returns:
      `[tl.Embedding, tl.Dropout, tl.PositionalEncoding]`.
    """
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    embed_layer = tl.Embedding(d_model, vocab_size)
    # TODO(kitaev): BroadcastedDropout?
    dropout_layer = tl.Dropout(rate=dropout, mode=mode)
    return [embed_layer, dropout_layer, position_layer]
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  dropout_shared_axes=None,
                  max_len=2048,
                  mode='train',
                  ff_activation=tl.Relu):
    """Builds a Transformer language model over a tensor of tokens.

    Only the decoder part of the overall Transformer is used.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train', 'eval' or 'predict'; predict is for fast inference
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a tensor of tokens to activations over the
      vocabulary.
    """
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    # Input: tokens (or chunked tuple of tokens).
    stages = [
        tl.ShiftRight(mode=mode),  # toks
        embedding_stack,           # vecs
        blocks,                    # vecs
        tl.LayerNorm(),            # vecs
        tl.Dense(vocab_size),      # vecs
        tl.LogSoftmax(),           # vecs
    ]
    return tl.Serial(*stages)
def ReZeroTransformerDecoder(vocab_size=None,
                             d_model=512,
                             d_ff=2048,
                             n_layers=6,
                             n_heads=8,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             max_len=2048,
                             mode='train',
                             ff_activation=tl.Relu):
    """Builds a ReZero transformer decoder for discrete or continuous input.

    With a `vocab_size` the input goes through an embedding; with
    `vocab_size=None` it goes through `tl.Dense(d_model)` instead. No
    shift-right layer is added, so the output for timestep t is based on
    inputs up to and including timestep t.

    Args:
      vocab_size: int or None: vocab size for discrete input, None otherwise.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a continuous or discrete tensor to a
      continuous tensor.
    """
    if vocab_size is not None:
        input_layer = tl.Embedding(vocab_size, d_model)
    else:
        input_layer = tl.Dense(d_model)

    embedding_stack = [
        input_layer,
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    # toks -> vecs -> vecs -> vecs
    return tl.Serial(embedding_stack, blocks, tl.LayerNorm())
def SkippingTransformerLM(vocab_size,
                          d_model=512,
                          d_ff=2048,
                          n_layers=6,
                          n_heads=8,
                          dropout=0.1,
                          max_len=2048,
                          mode='train',
                          ff_activation=tl.Relu):
    """Builds a Skipping Transformer language model over a tensor of tokens.

    Only decoder blocks are used; they are wrapped in `SkippingSerial`.

    Args:
      vocab_size: int: vocab size
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of decoder blocks
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train', 'eval' or 'predict'; predict is for fast inference
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a tensor of tokens to activations over the
      vocabulary.
    """
    token_embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='embedding', mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    shift = tl.ShiftRight(mode=mode)

    blocks = []
    for layer_idx in range(n_layers):
        block = transformer._DecoderBlock(  # pylint: disable=protected-access
            d_model, d_ff, n_heads, dropout, layer_idx, mode, ff_activation)
        blocks.append(block)
    skipping_stack = SkippingSerial(blocks, mode=mode)

    return tl.Serial(
        shift,
        token_embedder,
        skipping_stack,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    )
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Builds a Transformer encoder classifier over a tensor of tokens.

    Args:
      vocab_size: int: vocab size
      n_classes: how many classes on output
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder blocks
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A `tl.Serial` layer mapping a tensor of tokens to activations over
      `n_classes` output classes.
    """
    token_embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # Input: tokens.
    stages = [
        tl.Dup(),                                       # toks toks
        tl.Parallel(token_embedder, tl.PaddingMask()),  # vecs mask
    ]
    stages.append([
        EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                     ff_activation)
        for layer_idx in range(n_layers)
    ])                                                  # vecs mask
    stages.extend([
        tl.Parallel([], tl.Drop()),                     # ____  0
        tl.LayerNorm(),                                 # vecs
        tl.Mean(axis=1),  # Average on length.          # vecs
        tl.Dense(n_classes),                            # vecs
        tl.LogSoftmax(),                                # vecs
    ])
    return tl.Serial(*stages)
def PositionalEncoder(vocab_size, mode):
    """Returns the tokens --> vectors front-end as a list of layers.

    Uses the axial positional encoding when the enclosing scope's
    `axial_pos_shape` is truthy, and `tl.PositionalEncoding` otherwise. Also
    reads `max_len`, `dropout`, `d_axial_pos_embs` and `d_model` from the
    enclosing scope.

    Args:
      vocab_size: Vocabulary size for the embedding layer.
      mode: Forwarded to the dropout and positional-encoding layers.

    Returns:
      A list of three layers: embedding, dropout, positional encoding.
    """
    if axial_pos_shape:
        assert d_axial_pos_embs is not None
        n_axes = len(axial_pos_shape)
        position_layer = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
            dropout=dropout,
            mode=mode)
    else:
        position_layer = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=mode)

    embed_layer = tl.Embedding(vocab_size, d_model)
    dropout_layer = tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode)
    return [embed_layer, dropout_layer, position_layer]
def TransformerLM(vocab_size=33300,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  dropout=0.1,
                  max_len=4096,
                  mode='train',
                  ff_activation=tl.Relu):
    """Builds a Transformer language model over a tensor of tokens.

    The stacked blocks are created by the `Encoder` factory.

    Args:
      vocab_size (int): vocab size.
      d_model (int): depth of embedding.
      d_ff (int): depth of feed-forward layer.
      n_layers (int): number of stacked blocks.
      n_heads (int): number of attention heads.
      dropout (float): dropout rate (how much to drop out).
      max_len (int): maximum symbol length for positional encoding.
      mode (str): 'train', 'eval' or 'predict'; predict is for fast inference.
      ff_activation (function): the non-linearity in feed-forward layer.

    Returns:
      trax.layers.combinators.Serial: a layer mapping a tensor of tokens to
      activations over the vocabulary.
    """
    # Token embedding followed by dropout and positional encoding.
    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    blocks = []
    for _ in range(n_layers):
        blocks.append(
            Encoder(d_model, d_ff, n_heads, dropout, mode, ff_activation))

    stages = [
        tl.ShiftRight(mode=mode),
        embedding_stack,
        blocks,
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*stages)
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
         ):
    """BERT (default hparams are for bert-base-uncased).

    Args:
      d_model: Model depth; also used to derive the per-head depth.
      vocab_size: Size of the word-embedding vocabulary.
      max_len: Maximum length for the positional encoding.
      type_vocab_size: Size of the type-embedding vocabulary.
      n_heads: Number of attention heads.
      d_ff: Width of the first Dense layer in each feed-forward part.
      n_layers: Number of encoder layers.
      head: Optional zero-argument callable; when given, its result is
        appended after the BERT body in a `tl.Serial`.
      init_checkpoint: Checkpoint forwarded to `PretrainedBERT`; only used
        when `mode == 'train'`, otherwise replaced by None.
      mode: Forwarded to the positional encoding and self-attention layers.

    Returns:
      A `PretrainedBERT` layer, or a `tl.Serial` of it and `head()`.
    """
    norm_eps = 1e-12
    head_depth = d_model // n_heads

    word_emb = tl.Embedding(d_model, vocab_size)
    type_emb = tl.Embedding(d_model, type_vocab_size)
    position_emb = tl.PositionalEncoding(max_len, mode=mode)

    embedding_layers = [
        tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
        tl.Parallel(
            word_emb,
            type_emb,
            [tl.PaddingMask(),
             tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
        ),
        tl.Add(),
        position_emb,
        tl.LayerNorm(epsilon=norm_eps),
    ]

    def _encoder_layer():
        # One attention + feed-forward layer, each followed by a layer norm.
        attention = tl.SelfAttention(n_heads=n_heads, d_qk=head_depth,
                                     d_v=head_depth, bias=True, masked=True,
                                     mode=mode)
        ff_layers = [
            tl.Dense(d_ff),
            tl.Gelu(),
            tl.Dense(d_model),
        ]
        return [
            tl.Select([0, 1, 1]),  # Save a copy of the mask
            tl.Residual(attention, AddBias()),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(epsilon=norm_eps),
            tl.Residual(*ff_layers),
            tl.LayerNorm(epsilon=norm_eps),
        ]

    encoder_layers = []
    for _ in range(n_layers):
        encoder_layers.extend(_encoder_layer())
    encoder_layers.append(tl.Select([0], n_in=2))  # Drop the mask

    pooler_layers = [
        tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
        tl.Dense(d_model),
        tl.Tanh(),
    ]

    # A checkpoint is only loaded when training.
    if mode != 'train':
        init_checkpoint = None

    bert = PretrainedBERT(
        embedding_layers + encoder_layers + pooler_layers,
        init_checkpoint=init_checkpoint)

    if head is None:
        return bert
    return tl.Serial(bert, head())
def FunnelTransformerEncoder(vocab_size,
                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             encoder_segment_lengths=(2, 2, 2),
                             n_heads=8,
                             max_len=2048,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             mode='train',
                             ff_activation=tl.Relu,
                             pool_layer=tl.AvgPool,
                             pool_size=(2,),
                             strides=(2,),
                             separate_cls=True):
    """Returns a Funnel Encoder for text categorization.

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The
      tensor elements are integers in `range(vocab_size)`, and `0` values
      mark padding positions.
    - output: one vector of size `n_classes` per example, produced by the
      final `tl.Dense(n_classes)`. No softmax or log-softmax layer is
      appended by this function.

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`.
      n_classes: Final dimension of the output tensors, representing N-way
        classification.
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        encoder block.
      encoder_segment_lengths: Non-empty tuple; each element is the number of
        encoder blocks in one segment. A funnel block follows every segment
        except the last, so the total number of blocks is
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
      mode: If `'train'`, each encoder block will include dropout; else, it
        will pass all values through unaltered.
      ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
      pool_layer: Type of pooling layer used for downsampling in each funnel
        block; should be `tl.AvgPool` or `tl.MaxPool`.
      pool_size: Shape of the pooling window; forwarded to `_FunnelBlock`.
      strides: Offsets between neighboring pooling windows; forwarded to
        `_FunnelBlock`.
      separate_cls: If `True`, only the first token's final embedding is
        selected for classification (`SelectFirst`). If `False`, all
        embeddings are averaged over axis 1 instead. Also forwarded to each
        `_FunnelBlock`.

    Returns:
      A `tl.Serial` model mapping token IDs to activations over `n_classes`
      output classes.
    """
    assert encoder_segment_lengths

    embedding_stack = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    blocks = []
    last_segment = len(encoder_segment_lengths) - 1
    for segment_idx, segment_len in enumerate(encoder_segment_lengths):
        # `segment_len` plain encoder blocks make up this segment.
        for _ in range(segment_len):
            blocks.append(
                _EncoderBlock(d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation))
        # Every segment except the last is followed by a funnel block.
        if segment_idx != last_segment:
            blocks.append(
                _FunnelBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation,
                             pool_layer, pool_size, strides, separate_cls))

    if separate_cls:
        cls_pooling = SelectFirst()
    else:
        cls_pooling = tl.Mean(axis=1)

    # Encode: toks -> (vecs, masks).
    embed_and_mask = tl.Branch(embedding_stack, tl.PaddingMask())

    stages = [
        embed_and_mask,            # vecs masks
        blocks,                    # vecs masks
        tl.Select([0], n_in=2),    # vecs
        tl.LayerNorm(),            # vecs
        # Map to output categories.
        cls_pooling,               # cls
        tl.Dense(n_classes),       # cls
    ]
    return tl.Serial(*stages)
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
  """Returns a full Funnel Transformer (encoder, upsampler and decoder).

  The model can be used, for example, for BERT-style training. It produces one
  vocabulary-sized activation vector per input token:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). Elements
      are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor with one vector of `vocab_size` activations per
      token; shape is (batch_size, sequence_length, vocab_size). The model
      ends with a `tl.Dense(vocab_size)` layer; no softmax is applied here.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Non-empty tuple; element `i` is the number of
        encoder blocks in segment `i`. A funnel (pooling) block is placed
        between consecutive segments and none after the last one, so the
        encoder holds
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`
        blocks in total.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`. The same value is also used as
        the pooling stride of every funnel block.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper). The flag is
        forwarded to every funnel block and to the upsampler.

  Returns:
    A `tl.Serial` layer mapping token IDs to per-token activations of size
    `vocab_size`.
  """
  assert encoder_segment_lengths

  def _encoder_blocks(n_blocks):
    # All plain (non-pooling) blocks in this model share one configuration.
    return [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for _ in range(n_blocks)
    ]

  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  # Blocks that run at full sequence length, before any pooling.
  first_segment = _encoder_blocks(encoder_segment_lengths[0])

  # Every later segment is preceded by a funnel block that pools the sequence.
  pooled_segments = []
  for segment_length in encoder_segment_lengths[1:]:
    pooled_segments.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, pool_layer,
                     pool_size=pool_size,
                     strides=pool_size,
                     separate_cls=separate_cls))
    pooled_segments.extend(_encoder_blocks(segment_length))

  upsampling_decoder = _encoder_blocks(n_decoder_blocks)

  # One funnel block per segment boundary, each pooling by `pool_size[0]`.
  total_pool_size = pool_size[0] ** (len(encoder_segment_lengths) - 1)

  # Assemble and return the model; trailing comments track the data stack.
  return tl.Serial(                               # toks
      tl.Branch(embedder, tl.PaddingMask()),      # vecs masks
      first_segment,                              # vecs masks
      tl.Select([0, 1, 0, 1]),                    # vecs masks residual=vecs old_masks
      pooled_segments,                            # vecs masks residual masks
      tl.Select([0, 2, 3]),                       # vecs residual masks
      # The residual from the first segment was taken before normalization,
      # so normalize it now.
      tl.Parallel(None, tl.LayerNorm(), None),    # vecs norm(residual) masks
      _Upsampler(total_pool_size, separate_cls),  # vecs masks
      upsampling_decoder,
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
def test_simple_call(self):
  """Checks that `tl.PositionalEncoding` keeps the shape of its input."""
  inputs = np.array([[[2.0, 3.0, 4.0, 5.0],
                      [1.0, 2.0, 3.0, 4.0]]])
  pos_encoding = tl.PositionalEncoding(max_len=8)
  pos_encoding.init(shapes.signature(inputs))
  outputs = pos_encoding(inputs)
  self.assertEqual(outputs.shape, (1, 2, 4))
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape
  [batch, length // F, d_model] almost until the end -- at the end it's
  un-shortened and a SRU is applied. This reduces the length processed inside
  the main model body, effectively making the model faster but possibly
  slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class (or tuple/list of classes, cycled over the layers;
        `n_layers` must then be a multiple of its length): attention class to
        use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys. Forced on for layers
        whose attention class is a subclass of `tl.LSHCausalAttention`.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
        Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      # Keyword arguments keep the embedding depth and the vocabulary size
      # from being swapped if the positional order of `tl.Embedding` changes.
      tl.Embedding(d_feature=d_embedding, vocab_size=vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # A single attention class applies to every layer; a tuple/list is cycled.
  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]

  decoder_blocks = []
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),  # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline); 0 means the
        input is not chunked and is treated as a single section.
    n_attention_chunks: int: number of chunks for attention
    attention_type: class (or tuple/list of classes, cycled over the layers;
        `n_layers` must then be a multiple of its length): attention class to
        use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys. Forced on for layers
        whose attention class is a subclass of `tl.LSHCausalAttention`.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled. The strings
        'fixed-base', 'infinite', 'infinite-affine' and 'time-bin' select
        alternative positional encodings instead.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
        Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

  d_emb = d_model
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    # This encoding is paired with an embedding of half the model depth.
    d_emb //= 2
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      # Keyword arguments keep the embedding depth and the vocabulary size
      # from being swapped if the positional order of `tl.Embedding` changes.
      tl.Embedding(d_feature=d_emb, vocab_size=vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # A single attention class applies to every layer; a tuple/list is cycled.
  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]

  decoder_blocks = []
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline); 0 means the
        input is not chunked, in which case the single output section is
        passed through a final `tl.Concatenate`.
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    # Unchunked input: work with one section and concatenate on the way out.
    n_chunks = 1
    concatenate_input_chunks = []
    concatenate_output_chunks = tl.Concatenate(n_items=n_chunks, axis=-2)
  else:
    # Chunked input: merge the chunks on the way in, leave outputs split.
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)
    concatenate_output_chunks = []

  positional_embedder = [
      # Keyword arguments keep the embedding depth and the vocabulary size
      # from being swapped if the positional order of `tl.Embedding` changes.
      tl.Embedding(d_feature=d_model, vocab_size=vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.PositionalEncoding(max_len=max_len),
  ]

  return tl.Model(
      concatenate_input_chunks,
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial([
          # pylint: disable=g-complex-comprehension
          DecoderBlock(d_model, d_ff,
                       d_attention_key, d_attention_value, n_heads,
                       n_attention_chunks, attention_type,
                       dropout, share_qk, mode)
          for _ in range(n_layers)
      ] + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
      concatenate_output_chunks,
  )