def test_swap(self):
  """tl.Swap exchanges the top two elements of its input stack."""
  swap = tl.Swap()
  low = np.array([1, 2, 3])
  high = np.array([10, 20, 30])
  result = swap([low, high])
  # After swapping, `high` is first and `low` second.
  self.assertEqual(as_list(result), [[10, 20, 30], [1, 2, 3]])
def __init__(self, pre_attention, attention, post_attention):
  """Sets up a reversible half-residual built around an attention layer.

  The residual branch runs `pre_attention`, the wrapped `attention` and
  `post_attention` on a copy of x2; the forward pass then adds the branch
  output into x1_or_y1 with tl.Add, while `reverse_layers` recomputes the
  same branch and applies tl.SubtractTop instead.

  Args:
    pre_attention: layer(s) applied to the copy of x2 placed on top of the
      stack.
    attention: attention layer; it must define `forward_and_backward`.
    post_attention: layer(s) applied to the top stack element after the
      wrapped attention.
  """
  # Only attention layers exposing forward_and_backward are accepted.
  assert hasattr(attention, 'forward_and_backward')
  self.pre_attention = tl.Serial(
      tl.Parallel([], tl.Dup()),  # (x1_or_y1, x2, x2)
      tl.Swap(),  # (x2, x1_or_y1, x2)
      tl.Parallel(pre_attention, [], []),
  )
  self.attention = ApplyAttentionWrapper(attention)
  self.post_attention = tl.Parallel(post_attention, [], [])
  residual_branch = [self.pre_attention, self.attention, self.post_attention]
  # Forward: add the branch output into x1_or_y1.
  super(ReversibleAttentionHalfResidual, self).__init__(
      residual_branch + [tl.Parallel(tl.Add(), [])])
  # Reverse: recompute the same branch, then SubtractTop instead of Add.
  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = residual_branch + [self.subtract_top]
def __init__(self, residual_layers):
  """Sets up a reversible half-residual around `residual_layers`.

  Forward pass: `residual_layers` runs on a copy of x2, and tl.Add sums its
  output into x1_or_y1, leaving x2 in place. `reverse_layers` recomputes
  the same residual and applies tl.SubtractTop instead of tl.Add.

  Args:
    residual_layers: layer(s) computing the residual from x2.
  """
  self.compute_residual = tl.Serial(
      tl.Parallel([], tl.Dup()),  # (x1_or_y1, x2, x2)
      tl.Swap(),  # (x2, x1_or_y1, x2)
      tl.Parallel(residual_layers, [], []),
  )
  add_residual = tl.Parallel(tl.Add(), [])
  super(ReversibleHalfResidual, self).__init__(
      [self.compute_residual, add_residual])
  self.subtract_top = tl.Parallel(tl.SubtractTop(), [])
  self.reverse_layers = [self.compute_residual, self.subtract_top]
def loss(mask_id=None, has_weights=False):
  """Returns a scalar loss layer compatible with Trax masking.

  Cross-entropy on observations and L2 loss on rewards are summed, and the
  sum is then multiplied by 0.0 (the loss is zeroed out in this test).

  Args:
    mask_id: forwarded to both CrossEntropyLossScalar and L2LossScalar.
    has_weights: forwarded to both CrossEntropyLossScalar and L2LossScalar.
  """
  # (pred-obs, pred-reward, target-obs, target-reward)
  #   -> (pred-obs, target-obs, pred-reward, target-reward)
  regroup = layers.Parallel([], layers.Swap())
  per_stream_losses = layers.Parallel(
      layers.CrossEntropyLossScalar(mask_id, has_weights),
      layers.L2LossScalar(mask_id, has_weights),
  )
  return layers.Serial(
      regroup,
      per_stream_losses,
      layers.Add(),
      layers.MulConstant(constant=0.0),
  )
def loss(id_to_mask=None, has_weights=False):
  """Returns a scalar loss layer compatible with Trax masking.

  Cross-entropy on observations plus L2 loss on rewards, with the total
  multiplied by 0.0 so it is zeroed out in this test.

  Args:
    id_to_mask: forwarded to both CrossEntropyLoss and L2Loss.
    has_weights: forwarded to both CrossEntropyLoss and L2Loss.
  """
  # (pred-obs, pred-reward, target-obs, target-reward)
  #   -> (pred-obs, target-obs, pred-reward, target-reward)
  regroup = layers.Parallel([], layers.Swap())
  obs_loss = layers.CrossEntropyLoss(id_to_mask, has_weights)
  reward_loss = layers.L2Loss(id_to_mask, has_weights)
  return layers.Serial(
      regroup,
      layers.Parallel(obs_loss, reward_loss),
      layers.Add(),
      layers.Fn(lambda total: total * 0.0),
  )
def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                   ff_activation=tl.Relu):
  """Transformer encoder-decoder layer.

  The input is a triple (decoder_input, mask, encoder) where the mask is
  created from the original source to prevent attending to the padding part
  of the encoder.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    layer_idx: which layer are we at (for bookkeeping)
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer. Defaults to
      tl.Relu so callers that pass only the first six arguments keep working.

  Returns:
    the layer, returning a triple (decoder_activations, mask, encoder).
  """
  decoder_self_attention = [  # vecs_d masks vecs_e
      tl.LayerNorm(),  # vecs_d ..... ......
      tl.BasicCausalAttention(
          d_model, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  # vecs_d ..... ......
  ]
  decoder_to_encoder_attention = [  # vecs_d masks vecs_e
      tl.LayerNorm(),  # vecs_d masks vecs_e
      tl.Parallel([], [], tl.Dup()),  # ______ _____ vecs_e vecs_e
      tl.Parallel([], tl.Swap()),  # ______ vecs_e masks ......
      tl.Parallel([], tl.Dup()),  # ______ vecs_e vecs_e ..... ......
      tl.AttentionQKV(  # (q k v masks ... --> vecs_d masks ...)
          d_model, n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, mode=mode),  # vecs_d masks vecs_e
  ]
  feed_forward = [
      FeedForward(d_model, d_ff, dropout, layer_idx, mode, ff_activation),
  ]
  return tl.Serial(  # vecs_d masks vecs_e
      tl.Residual(decoder_self_attention),  # vecs_d masks vecs_e
      tl.Residual(decoder_to_encoder_attention),  # vecs_d masks vecs_e
      tl.Residual(feed_forward),  # vecs_d masks vecs_e
  )
def loss():
  """Returns a weighted scalar loss layer compatible with Trax masking.

  An all-ones weight tensor is attached to each target. Cross-entropy is
  computed on observations and L2 loss on rewards; their sum is multiplied
  by 0.0, so the loss is zeroed out in this test.
  """
  make_weights = layers.Fn(lambda x: math.numpy.ones_like(x))  # pylint: disable=unnecessary-lambda
  return layers.Serial(
      # (pred-obs, pred-reward, target-obs, target-reward)
      #   -> (pred-obs, target-obs, pred-reward, target-reward)
      layers.Parallel([], layers.Swap()),
      # target-obs -> (target-obs, ones): weights for the observation loss.
      layers.Parallel([], layers.Branch([], make_weights)),
      # target-reward (now 5th on the stack) -> (target-reward, ones).
      layers.Parallel([], [], [], [], layers.Branch([], make_weights)),
      # Each loss consumes (prediction, target, weights).
      layers.Parallel(layers.CrossEntropyLoss(), layers.L2Loss()),
      layers.Add(),
      layers.Fn(lambda total: total * 0.0),
  )
def __init__(self, residual_layers):
  """Sets up a reversible half-residual that passes extra stack items through.

  `residual_layers` is applied starting at a copy of x2; its output is
  summed into x1_or_y1 with tl.Add on the forward pass, or handled by
  tl.SubtractTop in `reverse_layers`. `self.n_preserve` is the number of
  stack items below (residual, x1_or_y1) that those final layers must leave
  untouched.

  Args:
    residual_layers: layer(s) computing the residual.
  """
  self.compute_residual = tl.Serial(
      tl.Parallel([], tl.Dup()),  # x1_or_y1, x2, x2, ...
      tl.Swap(),  # x2, x1_or_y1, x2, ...
      tl.Parallel([], [], residual_layers),  # x2, x1_or_y1, residual, ...
      tl.Select([2, 1, 0]),  # residual, x1_or_y1, x2, ...
  )
  self.n_preserve = self.compute_residual.n_out - 2
  passthrough = [[] for _ in range(self.n_preserve)]
  super(ReversibleHalfResidual, self).__init__(
      [self.compute_residual, tl.Parallel(tl.Add(), *passthrough)])
  self.subtract_top = tl.Parallel(tl.SubtractTop(), *passthrough)
  self.reverse_layers = [self.compute_residual, self.subtract_top]
def ApplyAndQueryPositions(layer, pos):
  """Runs `layer` on content only and each `pos` layer on (x, positions).

  The input x is split by CutAtPosition into a content part and a position
  part. `layer` is applied to the content, each `pos[i]` is applied to
  (x, x_position), and the len(pos) + 1 results are joined with
  tl.Concatenate.

  Args:
    layer: layer to be executed without position information.
    pos: list of layers to be applied to (x, x_position).

  Returns:
    the layer performing this application.
  """
  n_queries = len(pos)
  return tl.Serial(
      tl.Dup(),  # (x, x)
      CutAtPosition(),  # (x_content, x_position, x)
      tl.Parallel([], tl.Swap()),  # (x_content, x, x_position)
      # Copy the (x, x_position) pair until there is one per `pos` layer.
      [tl.Parallel([], Dup2()) for _ in range(n_queries - 1)],
      # Stack: x_content, (x, x_position) * n_queries.
      tl.Parallel(*([layer] + pos)),
      tl.Concatenate(n_items=n_queries + 1),
  )
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
  """Returns a Transformer model.

  This model expects an input pair: (source, target), i.e. encoder-side
  tokens first and decoder-side tokens second.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the encoder's
      embedding layer objects are reused on the decoder side.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in the feed-forward layers of the
      encoder-decoder blocks (the encoder blocks are built without it).

  Returns:
    A Transformer model as a layer that maps from a (source, target) pair to
    log-probabilities over a vocab set. A copy of the target tokens (kept for
    the loss) remains on the stack beneath the output.
  """
  in_embed = [  # tokens
      tl.Embedding(d_model, input_vocab_size),  # vecs
      tl.Dropout(rate=dropout, mode=mode),  # vecs
      tl.PositionalEncoding(max_len=max_len),  # vecs
  ]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
    out_embed = in_embed
  else:
    out_embed = [  # tokens
        tl.Embedding(d_model, output_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

  encoder_stack = (  # vecs_e masks --> vecs_e masks
      [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_encoder_layers)])

  encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
      [EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode,
                      ff_activation=ff_activation)
       for i in range(n_decoder_layers)])

  # Input: encoder_side_tokens, decoder_side_tokens
  return tl.Serial(  # tokens_e tokens_d
      tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
      tl.Swap(),  # toks_d toks_e ....

      # Encode.
      tl.Parallel(  # toks_d toks_e
          [], [tl.Dup(),  # ______ toks_e toks_e
               tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
               encoder_stack,  # ______ vecs_e masks
               tl.LayerNorm(),  # ______ vecs_e .....
               tl.Swap()]),  # ______ masks vecs_e

      # Decode.  Stack: toks_d masks vecs_e
      tl.ShiftRight(),  # toks_d ..... ......
      out_embed,  # vecs_d ..... ......
      tl.Dup(),  # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.EncoderDecoderMask()),  # ______ masks ......
      encoder_decoder_stack,  # vecs_d masks vecs_e
      tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d (toks_d below)
      tl.LayerNorm(),  # vecs_d
      tl.Dense(output_vocab_size),  # vecs_d
      tl.LogSoftmax(),  # vecs_d
  )
def LayerDropSkippingTransformerLM(vocab_size,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   dropout=0.1,
                                   max_len=2048,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   skip_fraction=0.4):
  """Returns a Skipping Transformer language model.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  A value n_layers_to_keep is drawn with tl.RandomUniform; decoder block i
  runs only if n_layers_to_keep > i. In 'train' mode the draw range is
  [0, n_layers / skip_fraction]; otherwise both bounds are n_layers, so
  every block runs.

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast
      inference
    ff_activation: the non-linearity in feed-forward layer
    skip_fraction: fraction of times to skip some layers. A non-positive
      value disables skipping (previously 0.0 raised ZeroDivisionError).

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  def ConditionedBlock(current_layer_num):
    return tl.Serial(
        # stack: embedding, n_layers_to_keep
        tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
        tl.Cond(
            # if n_layers_to_keep > current_layer_num
            LargerThan(float(current_layer_num)),
            # then: run block
            tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                d_model, d_ff, n_heads, dropout, [], mode, ff_activation)),
            # else: run noop
            tl.Serial()
        )
        # stack: embedding, n_layers_to_keep
    )

  if mode == 'train' and skip_fraction > 0.0:
    minimum_layers = 0.0
    maximum_layers = float(n_layers) / skip_fraction
  else:
    # No skipping: the draw is always n_layers, which exceeds every layer
    # index, so all blocks run. This is also the skip_fraction -> 0 limit of
    # the training formula above, and avoids dividing by zero.
    minimum_layers = maximum_layers = float(n_layers)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      # stack: embedding
      tl.RandomUniform(minimum_layers, maximum_layers, sync=True),
      # stack: n_layers_to_keep, embedding
      tl.Swap(),
      # stack: embedding, n_layers_to_keep
      [ConditionedBlock(i) for i in range(n_layers)],
      # stack: embedding, n_layers_to_keep
      tl.Select([0], n_in=2),
      # stack: embedding
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  )
def EveryOtherLayerDropTransformerLM(vocab_size,
                                     d_model=512,
                                     d_ff=2048,
                                     n_layers=6,
                                     n_heads=8,
                                     dropout=0.1,
                                     max_len=2048,
                                     mode='train',
                                     ff_activation=tl.Relu,
                                     skip_mode='even',
                                     skip_fraction=0.5,
                                     eval_skip_fraction=0.0):
  """Returns an "EveryOther" LayerDrop Transformer language model.

  One random value rnd is drawn per step with tl.RandomUniform(0., 1.).
  Layers selected by `skip_mode` run only if rnd > skip_fraction. All other
  layers use threshold 0.0, so they effectively always run. A step therefore
  either runs every layer or drops the same fixed subset of layers.

  The input to the model is a tensor of tokens. (This model uses only the
  decoder part of the overall Transformer.)

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder/decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast
      inference
    ff_activation: the non-linearity in feed-forward layer
    skip_mode: which layers to skip when skipping: even/odd/1half/2half.
      Any other value raises KeyError.
    skip_fraction: fraction of times to skip layers (used when mode is
      'train')
    eval_skip_fraction: fraction of times to skip layers in any non-'train'
      mode

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  if mode != 'train':
    skip_fraction = eval_skip_fraction

  # Predicate selecting the layers that may be skipped (layer 0 is even).
  skip_mode_fun = {
      'even': (lambda num: num % 2 == 0),
      'odd': (lambda num: num % 2 == 1),
      '1half': (lambda num: num < (n_layers / 2)),
      '2half': (lambda num: num >= (n_layers / 2)),
  }[skip_mode]

  @assert_shape('...sd,->...sd,')
  def ConditionedBlock(current_layer_num):
    threshold = skip_fraction if skip_mode_fun(current_layer_num) else 0.0
    return tl.Serial(
        # stack: embedding, rnd
        tl.Select([1, 0, 1]),  # rnd, embedding, rnd
        # Run the decoder block only when rnd > threshold; otherwise the
        # embedding passes through unchanged.
        tl.Cond(
            LargerThan(threshold),
            tl.Serial(
                transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
                    d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
        )
        # stack: embedding, rnd
    )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      embedder,
      # stack: embedding
      tl.RandomUniform(0., 1., sync=True),
      # stack: rnd, embedding
      tl.Swap(),
      # stack: embedding, rnd
      [ConditionedBlock(i) for i in range(n_layers)],
      # stack: embedding, rnd
      tl.Select([0], n_in=2),
      # stack: embedding
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
def Dup2():
  """Copies the top two stack elements: (a, b, ...) -> (a, b, a, b, ...)."""
  duplicate_each = tl.Parallel(tl.Dup(), tl.Dup())  # (a, a, b, b, ...)
  interleave = tl.Parallel([], tl.Swap())  # (a, b, a, b, ...)
  return [duplicate_each, interleave]