def test_run_reversible_weights_transfer_xprof(self):
  """Runs the reversible trainer and profiles weight transfer stats."""
  run_this_test = False  # We only run this test manually.
  if not run_this_test or fastmath.device_count() == 1:  # TPU only
    return

  # Create inputs and rngs.
  inputs_batch = np.ones((1024, 128), dtype=np.int32)
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 6
  rev_layers = []
  int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
  shape = shapes.ShapeDtype((1024, 128, 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.SGD

  # Make a step with the reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.

  # We profile here.
  t = time.time()
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))
def __init__(self, residual_layers):
  self.compute_residual = tl.Serial(         # x1_or_y1, x2, ...
      tl.Parallel([], tl.Dup()),             # x1_or_y1, x2, x2, ...
      tl.Swap(),                             # x2, x1_or_y1, x2, ...
      tl.Parallel([], [], residual_layers),  # x2, x1_or_y1, residual, ...
      tl.Select([2, 1, 0]),                  # residual, x1_or_y1, x2, ...
  )

  self.n_preserve = self.compute_residual.n_out - 2
  parallel_preserve = [[]] * self.n_preserve

  layers = [
      self.compute_residual,
      tl.Parallel(tl.Add(), *parallel_preserve)
  ]
  super(ReversibleHalfResidual, self).__init__(layers)

  self.subtract_top = tl.Parallel(tl.SubtractTop(), *parallel_preserve)
  self.reverse_layers = [self.compute_residual, self.subtract_top]
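# The stack bookkeeping in `compute_residual` can be checked with a
# weightless sketch. This is illustrative only: `TimesTwo` is a hypothetical
# stand-in for `residual_layers`, so no layer initialization is needed.
import numpy as np
from trax import layers as tl

residual = tl.Fn('TimesTwo', lambda x: 2 * x)  # stand-in residual branch
compute_residual = tl.Serial(            # x1, x2
    tl.Parallel([], tl.Dup()),           # x1, x2, x2
    tl.Swap(),                           # x2, x1, x2
    tl.Parallel([], [], residual),       # x2, x1, residual(x2)
    tl.Select([2, 1, 0]),                # residual(x2), x1, x2
)
x1, x2 = np.array([1., 1.]), np.array([3., 5.])
r, y1, y2 = compute_residual((x1, x2))
# r == 2 * x2, y1 == x1, y2 == x2 -- matching the stack comments above.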
def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):
  """ReversibleSerial but with a forgetting block every n_layers."""
  if not n_layers or len(layers) <= n_layers + 1:
    return tl.ReversibleSerial(layers)
  layers1, layers2 = layers[:n_layers], layers[n_layers:]

  if forget_dense:
    forgetting_layer = tl.Serial(
        _XYAvg(),
        tl.Dense(d_model),
        tl.Dup(),
    )
  else:
    forgetting_layer = tl.Select([0, 1])

  return tl.Serial(
      tl.ReversibleSerial(layers1),
      forgetting_layer,
      _ReversibleSerialForget(layers2, d_model, n_layers, forget_dense)
  )
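# A plain-Python sketch (illustrative only) of how the recursion above
# partitions `layers`: groups of n_layers with a forgetting block between
# groups, except that a would-be trailing group of <= 1 layer is absorbed
# into the previous one by the `len(layers) <= n_layers + 1` base case.
def _forget_groups(layers, n_layers):
  if not n_layers or len(layers) <= n_layers + 1:
    return [layers]
  return [layers[:n_layers]] + _forget_groups(layers[n_layers:], n_layers)

print(_forget_groups(list('ABCDE'), 2))  # [['A', 'B'], ['C', 'D', 'E']]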
def RawPolicy(seq_model, n_controls, n_actions):
  """Wraps a sequence model in a policy interface.

  The resulting model takes as input observation and action sequences, but
  only uses the observations. Adds output heads for action logits and value
  predictions.

  Args:
    seq_model: Trax sequence model taking as input and outputting a sequence
      of continuous vectors.
    n_controls: Number of controls.
    n_actions: Number of action categories in each control.

  Returns:
    A model of signature (obs, act) -> (act_logits, values), with shapes:
      obs: (batch_size, length + 1, obs_depth)
      act: (batch_size, length, n_controls)
      act_logits: (batch_size, length, n_controls, n_actions)
      values: (batch_size, length)
  """
  def SplitControls():  # pylint: disable=invalid-name
    """Splits logits for actions in different controls."""
    def f(x):
      return jnp.reshape(x, x.shape[:2] + (n_controls, n_actions))
    return tl.Fn('SplitControls', f)

  action_head = [
      # Predict all action logits at the same time.
      tl.Dense(n_controls * n_actions),
      # Then group them into separate controls, adding a new dimension.
      SplitControls(),
      tl.LogSoftmax(),
  ]
  return tl.Serial(                        # (obs, act)
      tl.Select([0], n_in=2),              # (obs,)
      seq_model,                           # (obs_hidden,)
      tl.Dup(),                            # (obs_hidden, obs_hidden)
      tl.Parallel(action_head, [tl.Dense(1), tl.Flatten()])
      #                                      (act_logits, values)
  )
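# A quick shape check for the SplitControls reshape (hypothetical sizes):
# logits of shape (batch, length, n_controls * n_actions) are regrouped into
# one row per control, adding a new trailing dimension.
import numpy as np

batch, length, n_controls, n_actions = 2, 5, 3, 4
x = np.zeros((batch, length, n_controls * n_actions))
y = np.reshape(x, x.shape[:2] + (n_controls, n_actions))
assert y.shape == (2, 5, 3, 4)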
def test_run_reversible_large_weights(self):
  """Runs the reversible trainer with a lot of weights to test memory use."""
  # This test requires > 20GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.device_count() == 1 and ram_limited:
    return

  # Create inputs and rngs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 20  # 20 layers each 16K x 16K = 256M weights ~= 1GB, 20GB ram
  rev_layers = []
  int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
  shape = shapes.ShapeDtype((2, 4, 16 * 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.Adafactor

  # Make a step with the reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  trainer.one_step(labeled_batch, rng_step)
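# Back-of-envelope check for the "20 layers ... 20GB ram" comment above: one
# 16K x 16K float32 Dense kernel is 16384 * 16384 * 4 bytes = 1 GiB.
print(16384 * 16384 * 4 / 2**30)  # 1.0 (GiB per layer, so ~20 GiB in total)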
def test_train_memory_efficient(self):
  """Trains a large network in a memory-efficient way."""
  # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.device_count() == 1 and ram_limited:
    return

  # Create the model.
  n_layers = 16  # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram
  model = tl.Serial(
      tl.Embedding(9, 16 * 1024),
      tl.Dup(),
      [[tl.ReversibleHalfResidual(tl.Dense(16 * 1024)), tl.ReversibleSwap()]
       for _ in range(n_layers)],
      tl.Concatenate(),
      tl.Dense(9),
  )

  # Create inputs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  def _data_gen():
    while True:
      yield labeled_batch

  # Run training.
  loss_layer = tl.WeightedCategoryCrossEntropy()
  task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor)
  eval_task = training.EvalTask(_data_gen(),
                                [tl.WeightedCategoryCrossEntropy()])
  loop = training.Loop(model, [task],
                       eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n == 2,
                       use_memory_efficient_trainer=True)
  self.assertEqual(0, loop.step)
  loop.run(n_steps=2)
  self.assertEqual(2, loop.step)
def ApplyAndQueryPositions(layer, pos):
  """Executes `layer` without positions, and the `pos` layers on positions.

  This takes an embedding including position x = (emb, p), and outputs
  layer(emb).pos1(x, p).....layer(emb).posn(x, p) where pos=[pos1...posn].

  Args:
    layer: layer to be executed without position information.
    pos: list of layers to be applied to positions.

  Returns:
    the result of this application.
  """
  n_heads = len(pos)
  return tl.Serial(
      tl.Dup(),                    # (x, x)
      CutAtPosition(),             # (x_content, x_position, x)
      tl.Parallel([], tl.Swap()),  # (x_content, x, x_position)
      [tl.Parallel([], Dup2()) for _ in range(n_heads - 1)],
      # Now the stack is x_content, (x, x_position) * n_heads.
      tl.Parallel(*([layer] + pos)),
      tl.Concatenate(n_items=n_heads + 1))
def PositionLookupTransformerLM(vocab_size=128,
                                d_model=256,
                                d_ff=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial(
      tl.ShiftRight(),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dup(),
      tl.Parallel([], NewPositionalEncoding(positions=positions)),
      [DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.Parallel([], tl.Drop()),  # Drop positions.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
    concatenate_output_chunks = tl.Concatenate(n_items=n_chunks, axis=-2)
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)
    concatenate_output_chunks = []

  positional_embedder = [
      tl.Embedding(d_model, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.PositionalEncoding(max_len=max_len),
  ]
  return tl.Model(
      concatenate_input_chunks,
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial([
          # pylint: disable=g-complex-comprehension
          DecoderBlock(d_model, d_ff,
                       d_attention_key, d_attention_value, n_heads,
                       n_attention_chunks, attention_type,
                       dropout, share_qk, mode)
          for _ in range(n_layers)
      ] + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
      concatenate_output_chunks,
  )
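# What the chunk plumbing above does, in isolation: with n_chunks=2, the
# per-chunk outputs are joined on the time axis (axis=-2). A weightless
# sketch with hypothetical shapes:
import numpy as np
from trax import layers as tl

a, b = np.zeros((2, 4, 8)), np.zeros((2, 4, 8))
out = tl.Concatenate(n_items=2, axis=-2)((a, b))
assert out.shape == (2, 8, 8)  # two length-4 chunks -> one length-8 sequence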
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              n_layers_forget=0,
              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None,
      the source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    n_layers_forget: how often to have a forgetting block between layers
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # Set default dimensions for attention head key and value sizes.
  if d_attention_key is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_attention_key = d_model // n_heads
  if d_attention_value is None:
    if d_model % n_heads != 0:
      raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})')
    d_attention_value = d_model // n_heads

  # Vector embeddings.
  def Embedder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
    ]
  in_embedder = Embedder(input_vocab_size)
  out_embedder = (in_embedder if output_vocab_size is None
                  else Embedder(output_vocab_size))

  def PositionalEnc(mode):
    return PositionalEncoding(mode, dropout, max_len, axial_pos_shape,
                              d_axial_pos_embs)

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = in_embedder + [PositionalEnc(encoder_mode)]
  out_encoder = out_embedder + [PositionalEnc(mode)]
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                   dropout=dropout, ff_activation=ff_activation,
                   ff_dropout=ff_dropout, ff_use_sru=ff_use_sru,
                   ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity,
                   mode=mode)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([          # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.Dense(d_model),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type, dropout=dropout,
        ff_activation=ff_activation, ff_dropout=ff_dropout,
        ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity, mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 0, 1, 1]),         # tok_e tok_e tok_e tok_d tok_d

      # Embed in and out tokens; done together as weights may be shared.
      tl.Parallel(in_encoder, [], [],     # vec_e tok_e tok_e vec_d tok_d
                  [tl.ShiftRight(mode=mode), out_encoder]),
      tl.Parallel([], [tl.PaddingMask(),
                       tl.Fn('Squeeze',
                             lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                     vec_e mask_e tok_e vec_d tok_d

      # Encode.
      encoder,                            # vec_e mask_e tok_e vec_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),            # vec_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given encoder mask.
      tl.Select([1, 0]),                  # vec_e vec_d mask_e tok_e tok_d
      t2.ConcatWithPadding(mode=mode),    # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                           # vec_ed1 vec_ed2 tok_e tok_d
      _ReversibleSerialForget(
          decoder_blocks, d_model,
          n_layers_forget),               # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                     # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                        # vec_ed tok_e tok_d tok_d
      t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),        # vec_d tok_d
      tl.LogSoftmax(),                    # vec_d tok_d
  )
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if share_qk:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads,
                                     d_head=d_attention_value),
        ),
        tl.Dup(),
    ]
  else:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(), tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads,
                                     d_head=d_attention_value),
        ),
    ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_model, d_ff, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention,
                                      post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
def test_noop_dup(self):
  layer = tl.Branch([], tl.Dup())
  x = np.array([1, 2, 3])
  ys = layer(x)
  self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]])
def test_default_name(self):
  layer = tl.Serial(tl.Dup(), tl.Dup())
  self.assertIn('Serial', str(layer))
def test_default_name(self):
  layer = tl.Parallel(tl.Dup(), tl.Dup())
  self.assertIn('Parallel', str(layer))
def test_exception_n_out(self):
  cond = SmallerThan(3.0)
  true = DivideBy(2.)
  false = tl.Dup()
  self.assertRaises(ValueError, lambda: tl.Cond(cond, true, false))
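# Why the above raises: tl.Cond requires both branches to have the same
# signature, and DivideBy(2.) maps 1 input -> 1 output while tl.Dup() maps
# 1 -> 2. A hedged sketch with matching 1 -> 1 branches constructs without
# error (`Positive`, `Half` and `Same` are hypothetical stand-ins):
from trax import layers as tl

cond = tl.Fn('Positive', lambda x: x > 0)
true = tl.Fn('Half', lambda x: x / 2.0)  # 1 -> 1
false = tl.Fn('Same', lambda x: x)       # 1 -> 1, matches `true`
layer = tl.Cond(cond, true, false)       # no ValueError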
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

  d_emb = d_model
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    d_emb //= 2
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_emb, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: target, source.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None,
      the source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long
    # sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size
  # is None. This isn't supported here because (a) Trax shares weights by
  # sharing layer instances, but we need two separate instances to have
  # mode == 'eval' for the encoder but mode == 'predict' for the decoder; and
  # (b) tl.Cache does not work if its sublayers participate in any weight
  # sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn(lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]),
      #                                     tok_e mask tok_d .....

      # Encode.
      encoder,                            # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),               # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),           # tok_d vec_e mask .....
      out_encoder,                        # vec_d vec_e mask .....
      tl.Dup(),                           # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                     # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d .....
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
def test_run_reversible_same_as_default_extended(self):
  """Runs the reversible trainer, checks results are the same as default."""
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = 2 * inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  # We want to test rng propagation too, so adding some dropout layers.
  first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
  rev_layers1 = [
      tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
      tl.ReversibleSwap()
  ]
  mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
  rev_layers2 = [
      tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
      tl.ReversibleSwap()
  ]
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                         tl.LogSoftmax(), tl.CrossEntropyLoss())
  model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] +
                    rev_layers2 + [loss_layer])
  rng_init = fastmath.random.get_prng(12)
  model.init(labeled_batch, rng=rng_init)
  optimizer_fn = optimizers.Adam  # to test slots

  # Make 3 steps with the original trainer.
  optimizer = optimizer_fn()
  optimizer.tree_init(model.weights)
  trainer = optimizers.Trainer(model, optimizer)
  rng_step1 = fastmath.random.get_prng(7)
  rng_step2 = fastmath.random.get_prng(8)
  rng_step3 = fastmath.random.get_prng(9)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
  first_layer_weights1 = first_layer.weights
  rev_layer12_weights1 = rev_layers1[2].weights
  mid_layer_weights1 = mid_layer.weights
  rev_layer20_weights1 = rev_layers2[0].weights
  loss_layer_weights1 = loss_layer.weights

  # Now make 3 steps with the reversible trainer.
  model.init(labeled_batch, rng=rng_init)
  # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why?
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer.sublayers, rev_layers1),
       (mid_layer.sublayers, rev_layers2)],
      loss_layer, optimizer_fn, memoize_jit=False)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

  # Check that weights end up the same.
  self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
  self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
  self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
  self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
  self._assert_all_equal(first_layer_weights1, first_layer.weights)
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk,
                 ff_activation, ff_use_sru, ff_chunk_size, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if not hasattr(attention_type, 'forward_unbatched'):
    if share_qk:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_value),
          ),
          tl.Dup(),
      ]
    else:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(), tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_value),
          ),
      ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear
    # in its input (so the backward pass can be computed without knowing the
    # input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]
    attention_half_residual = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
  else:
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        share_qk=share_qk, causal=True, output_dropout=dropout, mode=mode)
    attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=attention,
    )

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                                       dropout, ff_chunk_size, mode)]

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
def SerializedPolicy(seq_model, n_controls, n_actions, observation_serializer,
                     action_serializer):
  """Wraps a policy in serialization machinery for training.

  The resulting model takes as input observation and action sequences, and
  serializes them into one sequence similar to SerializedModel, before passing
  to the given sequence model. Adds output heads for action logits and value
  predictions.

  Args:
    seq_model: Trax sequence model taking as input a sequence of symbols and
      outputting a sequence of continuous vectors.
    n_controls: Number of controls.
    n_actions: Number of action categories in each control.
    observation_serializer: Serializer to use for observations.
    action_serializer: Serializer to use for actions.

  Returns:
    A model of signature (obs, act) -> (act_logits, values), same as in
    RawPolicy.
  """
  if action_serializer.representation_length != n_controls:
    raise ValueError(
        'Action symbols should correspond 1-1 to controls, but got {} '
        'controls and {} symbols.'.format(
            n_controls, action_serializer.representation_length))

  @tl.layer()
  def FirstSymbol(x, **unused_kwargs):
    return x[:, :, 0]

  @tl.layer()
  def PadRight(x, n_to_pad, **unused_kwargs):
    pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2)
    return np.pad(x, pad_widths, mode='constant',
                  constant_values=x.dtype.type(0))

  action_head = [
      # Drop the dummy action introduced before.
      DropLastTimestep(),  # pylint: disable=no-value-for-parameter
      tl.Dense(n_actions),
      tl.LogSoftmax(),
  ]
  value_head = [
      # Take just the vectors corresponding to the first action symbol.
      FirstSymbol(),  # pylint: disable=no-value-for-parameter
      # Predict values.
      tl.Dense(1),
      # Get rid of the singleton dimension.
      tl.Flatten(),
  ]
  return tl.Serial([                     # (obs, act)
      tl.Parallel(
          Serialize(serializer=observation_serializer),  # pylint: disable=no-value-for-parameter
          Serialize(serializer=action_serializer),  # pylint: disable=no-value-for-parameter
      ),                                 # (obs_repr, act_repr)
      Interleave(),                      # pylint: disable=no-value-for-parameter
      #                                    (obs_act_repr,)

      # Add one dummy action to the right - we'll use the output at its first
      # symbol to predict the value for the last observation.
      PadRight(n_to_pad=action_serializer.representation_length),  # pylint: disable=no-value-for-parameter
      # Shift one symbol to the right, so we predict the n-th action symbol
      # based on action symbols 1..n-1 instead of 1..n.
      tl.ShiftRight(),
      seq_model,                         # (obs_act_hidden,)
      Deinterleave(  # pylint: disable=no-value-for-parameter
          x_size=observation_serializer.representation_length,
          y_size=action_serializer.representation_length,
      ),                                 # (obs_hidden, act_hidden)
      tl.Select([1]),                    # (act_hidden,)
      tl.Dup(),                          # (act_hidden, act_hidden)
      tl.Parallel(action_head, value_head)  # (act_logits, values)
  ])
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on
  length) into a single vector -- so that in the end we process a tensor of
  shape [batch, length // F, d_model] almost until the end -- at the end it's
  un-shortened and an SRU is applied. This reduces the length processed
  inside the main model body, effectively making the model faster but
  possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, values must sum to
      d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),              # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),      # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),   # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
  """Returns a Transformer model.

  This model expects an input pair: target, source.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None,
      the source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  in_embed = [                                   # tokens
      tl.Embedding(d_model, input_vocab_size),   # vecs
      tl.Dropout(rate=dropout, mode=mode),       # vecs
      tl.PositionalEncoding(max_len=max_len),    # vecs
  ]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
    out_embed = in_embed
  else:
    out_embed = [                                 # tokens
        tl.Embedding(d_model, output_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),      # vecs
        tl.PositionalEncoding(max_len=max_len),   # vecs
    ]

  encoder_stack = (  # masks vectors --> masks vectors
      [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_encoder_layers)])

  encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
      [EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
       for i in range(n_decoder_layers)])

  # Input: encoder_side_tokens, decoder_side_tokens
  return tl.Serial(                              # tokens_e tokens_d
      tl.Parallel([], tl.Dup()),                 # toks_e toks_d toks_d (for loss)
      tl.Swap(),                                 # toks_d toks_e ....

      # Encode.
      tl.Parallel(                               # toks_d toks_e
          [], [tl.Dup(),                         # ______ toks_e toks_e
               tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
               encoder_stack,                    # ______ vecs_e masks
               tl.LayerNorm(),                   # ______ vecs_e .....
               tl.Swap()]),                      # ______ masks vecs_e

      # Decode.                                  #        toks_d masks vecs_e
      tl.ShiftRight(),                           # toks_d ..... ......
      out_embed,                                 # vecs_d ..... ......
      tl.Dup(),                                  # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.EncoderDecoderMask()),  # ______ masks ......
      encoder_decoder_stack,                     # vecs_d masks vecs_e
      tl.Parallel([], tl.Drop(), tl.Drop()),     # vecs_d
      tl.LayerNorm(),                            # vecs_d
      tl.Dense(output_vocab_size),               # vecs_d
      tl.LogSoftmax(),                           # vecs_d
  )
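# The first two combinators above are weightless, so their stack effect can
# be checked directly (illustrative token arrays only):
import numpy as np
from trax import layers as tl

toks_e, toks_d = np.array([[1, 2]]), np.array([[3, 4]])
plumbing = tl.Serial(tl.Parallel([], tl.Dup()), tl.Swap())
out = plumbing((toks_e, toks_d))
# out == (toks_d, toks_e, toks_d): decoder tokens on top for decoding, with
# a copy kept at the bottom for the loss, as in the stack comments above.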
def test_dup_dup(self):
  layer = tl.Parallel(tl.Dup(), tl.Dup())
  xs = [np.array([1, 2, 3]), np.array([10, 20])]
  ys = layer(xs)
  self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [10, 20], [10, 20]])
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             axial_pos_shape=None,
             d_axial_pos_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: target, source.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None,
      the source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          [-2],  # dropout_shared_axes
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                   ff_activation, ff_dropout, mode=mode,
                   ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size,
                   ff_sparsity=ff_sparsity)
      for _ in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # pylint: disable=g-complex-comprehension
  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                          ff_dropout, mode, ff_use_sru=ff_use_sru,
                          ff_chunk_size=ff_chunk_size,
                          ff_sparsity=ff_sparsity)
      for _ in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                      tok_e mask tok_d .....

      # Encode.
      encoder,                             # vec_e mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),            # tok_d vec_e mask .....
      out_encoder,                         # vec_d vec_e mask .....
      tl.Dup(),                            # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                      # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),              # vec_d .....
      tl.Dense(output_vocab_size),         # vec_d .....
  )
def test_custom_name(self):
  layer = tl.Parallel(tl.Dup(), tl.Dup(), name='DupDup')
  self.assertIn('DupDup', str(layer))
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               loss_sparsity_type='mult',
               loss_sparsity=0,
               loss_d_lowrank=0,
               loss_sparsity_prob=None,
               attention_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to use in the loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to
      be used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  positional_encoding = ct.PositionalEncoder(
      mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)

  positional_embedder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_dropout=dropout,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        ff_sparsity=ff_sparsity,
        attention_chunk_size=attention_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  dense_loss_layer = tl.SparseDenseWithOptions(
      vocab_size,
      d_input=d_model,
      sparsity_type=loss_sparsity_type,
      sparsity=loss_sparsity,
      d_lowrank=loss_d_lowrank,
      prob_sparse=loss_sparsity_prob,
      mode=mode)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks),
      tl.Concatenate(),
      # TODO(kitaev): Test whether dropout should go before or after the
      # LayerNorm, and whether dropout broadcasting is needed here.
      tl.LayerNorm(),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
      dense_loss_layer,
  )
def test_dup_dup(self):
  layer = tl.Serial(tl.Dup(), tl.Dup())
  x = np.array([1, 2, 3])
  ys = layer(x)
  self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]])
def FunnelTransformerLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        vanilla_layers=(0, 1),
                        shorten_factors=(3,),
                        n_funnel_blocks=(6,),
                        n_heads=8,
                        dropout=0.1,
                        dropout_shared_axes=None,
                        mode='train',
                        ff_activation=tl.FastGelu):
  """Returns a Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token
      IDs plus padding markers; shape is (batch_size, sequence_length). The
      tensor elements are integers in `range(vocab_size)`, and `0` values
      mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  This model uses only the decoder part of the overall Transformer.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model,
      including the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each
      encoder block.
    vanilla_layers: (pre_layers, post_layers) tuple - number of full
      token-level Transformer decoder layers before and after shortening.
    shorten_factors: by how much to shorten at each step - tuple of arbitrary
      length denoting by how much to shorten at each pooling stage.
    n_funnel_blocks: number of Transformer decoder blocks after each stage
      of pooling - tuple of the same length as `shorten_factors`.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: str: 'train' or 'eval'.
    ff_activation: Type of activation function at the end of each encoder
      block; must be an activation-type subclass of `Layer`.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  assert mode != 'predict'  # For now, 'predict' mode is unsupported.
  assert len(n_funnel_blocks) == len(shorten_factors)

  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)]

  context_bias_layer, location_bias_layer = _get_rel_att_inputs(d_model,
                                                                n_heads)

  n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

  def create_decoder_blocks(n_layers, total_pooling):  # pylint: disable=invalid-name
    decoder_blocks = [
        # pylint: disable=g-complex-comprehension
        _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation,
                              context_bias_layer, location_bias_layer,
                              total_pooling)
        for _ in range(n_layers)]
    return decoder_blocks + [tl.LayerNorm()]

  total_pooling_acc = 1
  pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks,
                                             total_pooling=1)

  funnel_blocks = []

  for shorten_factor, block_len in zip(shorten_factors, n_funnel_blocks):
    funnel_blocks = funnel_blocks + [_FunnelRelativeDecoderBlock(
        d_model, d_ff, n_heads, dropout,
        dropout_shared_axes, mode,
        ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=total_pooling_acc,
        shorten_factor=shorten_factor,
        resampler_fn=_DownsamplerLM)]
    total_pooling_acc *= shorten_factor
    funnel_blocks = funnel_blocks + create_decoder_blocks(block_len,
                                                          total_pooling_acc)

  upsampling_layer = _FunnelRelativeDecoderBlock(
      d_model, d_ff, n_heads, dropout,
      dropout_shared_axes, mode,
      ff_activation,
      context_bias_layer=context_bias_layer,
      location_bias_layer=location_bias_layer,
      total_pooling=total_pooling_acc,
      shorten_factor=total_pooling_acc,
      resampler_fn=_UpsamplerLM)

  conv_layer = tl.Serial(
      tl.CausalConv(d_model, total_pooling_acc),
      ff_activation()
  )

  post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks,
                                              total_pooling=1)

  # Assemble and return the model.
  return tl.Serial(              # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      token_encoder,             # vecs
      pre_decoder_blocks,        # vecs
      tl.Dup(),
      tl.ShiftRight(n_positions=total_pooling_acc - 1),
      funnel_blocks,
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
      upsampling_layer,
      tl.LayerNorm(),
      tl.Concatenate(),
      conv_layer,
      post_decoder_blocks,
      tl.Dense(vocab_size),      # vecs
  )
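# How total_pooling_acc accumulates in the loop above: with a hypothetical
# shorten_factors=(3, 2), the deepest blocks see length // 6 positions, and
# the upsampling layer expands back by the same accumulated factor.
total_pooling_acc = 1
for shorten_factor in (3, 2):
  total_pooling_acc *= shorten_factor
print(total_pooling_acc)  # 6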
def test_custom_name(self):
  layer = tl.Serial(tl.Dup(), tl.Dup(), name='Branch')
  self.assertIn('Branch', str(layer))
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None,
      the source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size
  # is None. This isn't supported here because (a) Trax shares weights by
  # sharing layer instances, but we need two separate instances to have
  # mode == 'eval' for the encoder but mode == 'predict' for the decoder; and
  # (b) tl.Cache does not work if its sublayers participate in any weight
  # sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type,
                   dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([          # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),             # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                      tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                             # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),             # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),            # stok_d vec_e mask_e tok_e tok_d
      tl.Branch([], _MaskOfRightShiftedArray()),
      #                               stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                    # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),        # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),                # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                            # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
      tl.LayerNorm(),                      # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),             # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),         # vec_d tok_d
      tl.LogSoftmax(),                     # vec_d tok_d
  )