def AttentionPosition(positions, d_model, n_heads=8, dropout=0.0,
                      mode='train'):
  """Transformer-style multi-headed attention."""
  return tl.Serial(
      tl.Dup(),
      tl.Dup(),
      tl.Parallel(
          ApplyAndQueryPositions(tl.Dense(d_model),
                                 pos=[SumLearnedPick(positions)
                                      for _ in range(n_heads)]),
          PreservePosition(tl.Dense(d_model)),
          PreservePosition(tl.Dense(d_model)),
      ),
      tl.Parallel(
          CopyHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
          MixHeadsPos(h=n_heads),
      ),
      tl.PureAttention(d_model=d_model, n_heads=n_heads, dropout=dropout,
                       mode=mode),
      tl.Parallel([], tl.Drop()),  # Drop the mask.
      CombineHeadsPos(h=n_heads),
      PreservePosition(tl.Dense(d_model)),
  )
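# The Dup/Parallel/Drop calls above are data-stack plumbing. A minimal,
# self-contained sketch of that pattern (illustrative only; `stack_demo` is
# not part of the model above):
import numpy as np
from trax import layers as tl

stack_demo = tl.Serial(
    tl.Dup(),                      # stack: (x,)    -> (x, x)
    tl.Parallel([], tl.Negate()),  # stack: (x, x)  -> (x, -x)
    tl.Parallel([], tl.Drop()),    # stack: (x, -x) -> (x,), like "Drop the mask"
)

x = np.array([1.0, 2.0, 3.0])
y = stack_demo(x)  # y == x; the negated copy was dropped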
def SerializedModel(
    seq_model,
    observation_serializer,
    action_serializer,
    significance_decay,
):
  """Wraps a world model in serialization machinery for training.

  The resulting model takes as input the observation and action sequences,
  serializes them and interleaves them into one sequence, which is fed into
  the given autoregressive model. The resulting logit sequence is
  deinterleaved into observations and actions, and the observation logits are
  returned together with computed symbol significance weights.

  Args:
    seq_model: Trax autoregressive model taking as input a sequence of symbols
      and outputting a sequence of symbol logits.
    observation_serializer: Serializer to use for observations.
    action_serializer: Serializer to use for actions.
    significance_decay: Float from (0, 1) for exponential weighting of symbols
      in the representation.

  Returns:
    A model of signature (obs, act, obs, mask) -> (obs_logits, obs_repr,
    weights), where obs are observations (the second occurrence is the
    target), act are actions, mask is the observation mask, obs_logits are
    logits of the output observation representation, obs_repr is the target
    observation representation and weights are the target weights.
  """
  # pylint: disable=no-value-for-parameter
  weigh_by_significance = [     # (mask,)
      RepresentationMask(serializer=observation_serializer),
      # (repr_mask)
      SignificanceWeights(serializer=observation_serializer,
                          decay=significance_decay),
      # (mask, sig_weights)
  ]
  return tl.Serial(             # (obs, act, obs, mask)
      tl.Parallel(Serialize(serializer=observation_serializer),
                  Serialize(serializer=action_serializer),
                  Serialize(serializer=observation_serializer)),
      # (obs_repr, act_repr, obs_repr, mask)
      Interleave(),             # (obs_act_repr, obs_repr, mask)
      seq_model,                # (obs_act_logits, obs_repr, mask)
      Deinterleave(x_size=observation_serializer.representation_length,
                   y_size=action_serializer.representation_length),
      # (obs_logits, act_logits, obs_repr, mask)
      tl.Parallel(None, tl.Drop(), None, weigh_by_significance),
      # (obs_logits, obs_repr, weights)
  )
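# The interleaving step can be pictured in plain numpy. The helper below is
# hypothetical and only mimics what Interleave() does for fixed-length
# per-step representations; the real layer works on serialized symbol
# sequences.
import numpy as np

def interleave(obs_repr, act_repr):
  """Merges (batch, steps, x_size) observation and (batch, steps, y_size)
  action representations into one (batch, steps * (x_size + y_size)) sequence.
  """
  batch = obs_repr.shape[0]
  merged = np.concatenate([obs_repr, act_repr], axis=2)
  return np.reshape(merged, (batch, -1))

obs = np.zeros((1, 2, 3), dtype=np.int32)  # x_size = 3 symbols per step
act = np.ones((1, 2, 1), dtype=np.int32)   # y_size = 1 symbol per step
seq = interleave(obs, act)                 # layout: [o o o a | o o o a]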
def __init__(
    self,
    seq_model,
    observation_serializer,
    action_serializer,
    significance_decay,
    mode='train',
):
  """Initializes SerializedModel.

  Args:
    seq_model: Trax autoregressive model taking as input a sequence of symbols
      and outputting a sequence of symbol logits.
    observation_serializer: Serializer to use for observations.
    action_serializer: Serializer to use for actions.
    significance_decay: Float from (0, 1) for exponential weighting of symbols
      in the representation.
    mode: 'train' or 'eval'.
  """
  assert mode in ('train', 'eval')
  weigh_by_significance = [     # (mask,)
      RepresentationMask(serializer=observation_serializer),
      # (repr_mask)
      SignificanceWeights(serializer=observation_serializer,
                          decay=significance_decay),
      # (mask, sig_weights)
  ]
  super().__init__(             # (obs, act, obs, mask)
      tl.Parallel(Serialize(serializer=observation_serializer),
                  Serialize(serializer=action_serializer),
                  Serialize(serializer=observation_serializer)),
      # (obs_repr, act_repr, obs_repr, mask)
      Interleave(),             # (obs_act_repr, obs_repr, mask)
      seq_model(mode=mode),     # (obs_act_logits, obs_repr, mask)
      Deinterleave(x_size=observation_serializer.representation_length,
                   y_size=action_serializer.representation_length),
      # (obs_logits, act_logits, obs_repr, mask)
      tl.Parallel(None, tl.Drop(), None, weigh_by_significance),
      # (obs_logits, obs_repr, weights)
  )

  self._seq_model = seq_model
  self._observation_serializer = observation_serializer
  self._action_serializer = action_serializer
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
  """Returns a Transformer encoder model.

  The input to the model is a tensor of tokens.

  Args:
    vocab_size: int: vocab size
    n_classes: how many classes on output
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a tensor of tokens to
    activations over a set of output classes.
  """
  embedder = [
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]
  return tl.Serial(                             #      tokens
      tl.Dup(),                                 # toks toks
      tl.Parallel(embedder, tl.PaddingMask()),  # vecs mask
      [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
       for i in range(n_layers)],               # vecs mask
      tl.Parallel([], tl.Drop()),               # vecs  (mask dropped)
      tl.LayerNorm(),                           # vecs
      tl.Mean(axis=1),                          # vecs  (average on length)
      tl.Dense(n_classes),                      # vecs
      tl.LogSoftmax(),                          # vecs
  )
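# A minimal usage sketch, assuming a Trax version where models are
# initialized via init() with a shapes.signature(...) input signature
# (the exact API has shifted across Trax releases):
import numpy as np
from trax import shapes

model = TransformerEncoder(vocab_size=32, n_classes=2, d_model=16, d_ff=32,
                           n_layers=1, n_heads=2, mode='eval')
tokens = np.array([[5, 3, 7, 0, 0, 0, 0, 0]], dtype=np.int32)  # 0 = padding
model.init(shapes.signature(tokens))
log_probs = model(tokens)  # shape (1, 2): class log-probabilities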
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  The input to the model is a tensor of tokens (ints).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding (n_units in the RNN cell)
    n_layers: int: number of RNN layers
    rnn_cell: the RNN cell
    rnn_cell_d_state_multiplier: how many times larger the RNN cell state is
      than d_model
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train', 'eval' or 'predict'; 'predict' mode is for fast
      inference

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens to
    activations over a vocab set.
  """
  def MultiRNNCell():
    """Multi-layer RNN cell."""
    assert n_layers == 2
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='embedding', mode=mode),
      tl.Dup(),  # Duplicate to create parallel state.
      tl.Parallel([], tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
          depth_multiplier=n_layers * rnn_cell_d_state_multiplier)),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Parallel([], tl.Drop()),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
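# A minimal usage sketch under the same API assumptions as the encoder
# example above (init() with a shapes.signature(...) signature):
import numpy as np
from trax import shapes

model = RNNLM(vocab_size=16, d_model=8, n_layers=2, mode='eval')
tokens = np.array([[1, 2, 3, 4]], dtype=np.int32)
model.init(shapes.signature(tokens))
log_probs = model(tokens)  # shape (1, 4, 16): next-token log-probabilities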
def PositionLookupTransformerLM(vocab_size=128,
                                d_model=256,
                                d_ff=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial(
      tl.ShiftRight(),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dup(),
      tl.Parallel([], NewPositionalEncoding(positions=positions)),
      [DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.Parallel([], tl.Drop()),  # Drop positions.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def test_drop(self):
  layer = tl.Drop()
  x = np.array([1, 2, 3])
  y = layer(x)
  self.assertEqual(as_list(y), [])
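def test_drop_inside_serial(self):
  # A companion check in the same style (a sketch assuming the same test
  # harness and its as_list helper): Dup pushes a copy onto the stack and
  # Drop removes only the top item, so the original input survives.
  layer = tl.Serial(tl.Dup(), tl.Drop())
  x = np.array([1, 2, 3])
  y = layer(x)
  self.assertEqual(as_list(y), [1, 2, 3])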
    'weights',     # Model weights.
    'slots',       # Per-parameter optimizer state, e.g. gradient moments.
    'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
])

_DEFAULT_METRICS = {
    'loss': tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()),
    'accuracy': tl.Accuracy(),
    'sequence_accuracy': tl.SequenceAccuracy(),
    'neg_log_perplexity': tl.Serial(tl.LogSoftmax(),
                                    tl.CrossEntropyLoss(),
                                    tl.Negate()),
    'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}


class Trainer:
  """Trax trainer.

  A trainer allows one to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """

  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
OptState = collections.namedtuple('_OptState', [
    'weights',     # Model weights.
    'slots',       # Per-parameter optimizer state, e.g. gradient moments.
    'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
])

_DEFAULT_METRICS = {
    'loss': tl.WeightedCategoryCrossEntropy(),
    'accuracy': tl.WeightedCategoryAccuracy(),
    'sequence_accuracy': tl.MaskedSequenceAccuracy(),
    'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(),
                                    tl.Negate()),
    'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}


class Trainer:
  """Trax trainer.

  A trainer allows one to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """

  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
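# What 'weights_per_batch_per_core' computes, shown on a hypothetical
# (model_output, targets, weights) triple; this sketch assumes tl.Sum()'s
# default reduction is over all axes:
import numpy as np
from trax import layers as tl

metric = tl.Serial(tl.Drop(), tl.Drop(), tl.Sum())
model_output = np.zeros((2, 3, 8))           # dropped
targets = np.zeros((2, 3), dtype=np.int32)   # dropped
weights = np.ones((2, 3))                    # summed
total = metric((model_output, targets, weights))  # 6.0 non-padding symbols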
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  in_embed = [                                  # tokens
      tl.Embedding(d_model, input_vocab_size),  # vecs
      tl.Dropout(rate=dropout, mode=mode),      # vecs
      tl.PositionalEncoding(max_len=max_len),   # vecs
  ]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
    out_embed = in_embed
  else:
    out_embed = [                                  # tokens
        tl.Embedding(d_model, output_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),       # vecs
        tl.PositionalEncoding(max_len=max_len),    # vecs
    ]

  encoder_stack = (  # masks vectors --> masks vectors
      [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_encoder_layers)])

  encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
      [EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_decoder_layers)])

  # Input: encoder_side_tokens, decoder_side_tokens
  return tl.Serial(                # tokens_e tokens_d
      tl.Parallel([], tl.Dup()),   # toks_e toks_d toks_d (for loss)
      tl.Swap(),                   # toks_d toks_e ....

      # Encode.
      tl.Parallel(                 # toks_d        toks_e
          [], [tl.Dup(),           # ______ toks_e toks_e
               tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
               encoder_stack,      # ______ vecs_e masks
               tl.LayerNorm(),     # ______ vecs_e .....
               tl.Swap()]),        # ______ masks  vecs_e

      # Decode.                                  #        toks_d masks vecs_e
      tl.ShiftRight(),                           # toks_d ..... ......
      out_embed,                                 # vecs_d ..... ......
      tl.Dup(),                                  # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.EncoderDecoderMask()),  # ______ masks ......
      encoder_decoder_stack,                     # vecs_d masks vecs_e
      tl.Parallel([], tl.Drop(), tl.Drop()),     # vecs_d
      tl.LayerNorm(),                            # vecs_d
      tl.Dense(output_vocab_size),               # vecs_d
      tl.LogSoftmax(),                           # vecs_d
  )
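# A minimal two-input usage sketch (same init()/shapes.signature assumptions
# as the earlier examples). Note the extra output: the decoder tokens
# duplicated at the start of the Serial come back alongside the log-probs,
# for use in the loss.
import numpy as np
from trax import shapes

model = Transformer(input_vocab_size=16, d_model=16, d_ff=32,
                    n_encoder_layers=1, n_decoder_layers=1, n_heads=2,
                    mode='eval')
source = np.array([[3, 5, 7, 0]], dtype=np.int32)
target = np.array([[2, 4, 6, 8]], dtype=np.int32)
model.init(shapes.signature((source, target)))
log_probs, target_copy = model((source, target))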
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  def PositionalEmbedder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  def EncoderBlocks(n_blocks):  # vectors masks --> vectors masks
    return [_EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                          ff_activation)
            for i in range(n_blocks)]

  def EncoderDecoderBlocks(n_blocks):  # vectors masks --> vectors masks
    return [_EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                                 ff_activation)
            for i in range(n_blocks)]

  in_embed = PositionalEmbedder(input_vocab_size)
  out_embed = (in_embed if output_vocab_size is None
               else PositionalEmbedder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  # Assemble and return the model.
  return tl.Serial(                            # Input: tok_e tok_d
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

      # Encode.
      tl.Branch(in_embed, tl.PaddingMask()),   # vec_e masks ..... .....
      EncoderBlocks(n_encoder_layers),         # vec_e masks ..... .....
      tl.LayerNorm(),                          # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
      tl.ShiftRight(),                         # tok_d ..... ..... .....
      out_embed,                               # vec_d ..... ..... .....
      tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
      EncoderDecoderBlocks(n_decoder_layers),  # vec_d masks ..... .....
      tl.LayerNorm(),                          # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Parallel([], tl.Drop(), tl.Drop()),   # vec_d tok_d
      tl.Dense(output_vocab_size),             # vec_d .....
      tl.LogSoftmax(),                         # vec_d .....
  )
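# tl.Select([0, 1, 1]) replaces the older Parallel/Dup/Swap dance: it
# rearranges (and can copy) stack items by index. A small illustration on
# concrete arrays:
import numpy as np
from trax import layers as tl

sel = tl.Select([0, 1, 1])  # keep item 0, then two copies of item 1
tok_e = np.array([[1, 2]])
tok_d = np.array([[3, 4]])
outs = sel((tok_e, tok_d))  # (tok_e, tok_d, tok_d): extra copy for the loss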
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (along
  length) into a single vector -- so that in the end we process a tensor of
  shape [batch, length // F, d_model] almost until the end -- at the end it's
  un-shortened and an SRU is applied. This reduces the length processed inside
  the main model body, effectively making the model faster but possibly
  slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, values must sum to
      d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),  # Stack has (x, x); the first will be shortened.
      # Before shortening, we need to pad by the shorten factor so as not to
      # leak information into the future. To understand why, imagine a
      # shorten factor of 2 and a sequence of length 4, so ABCD. If we shift
      # just by 1, then we would have 0ABC, which gets grouped to [0A][BC] on
      # input, which is predicting ABCD as targets. The problem is that [0A]
      # has access to A and [BC] has access to C -- the model will learn to
      # copy them, peeking into the future. Shifting twice to [00][AB] solves
      # the problem, as the first "big" symbol becomes all-0 and the rest is
      # shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x).
      tl.ReversibleSerial(decoder_blocks),
      tl.Parallel([], tl.Drop()),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax())
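# The shorten/prolong steps are plain length-to-depth reshapes; a numpy
# sketch with illustrative sizes:
import numpy as np

batch, length, d_embedding, shorten_factor = 2, 8, 4, 2
x = np.arange(batch * length * d_embedding, dtype=np.float32).reshape(
    (batch, length, d_embedding))

# Shorten: group each pair of positions into one vector of twice the depth.
short = np.reshape(x, (x.shape[0], x.shape[1] // shorten_factor, -1))
assert short.shape == (2, 4, 8)

# Prolong: the inverse reshape restores the original length.
long_again = np.reshape(short,
                        (short.shape[0], short.shape[1] * shorten_factor, -1))
assert np.array_equal(long_again, x)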