def test_run_reversible_same_as_default_extended(self):
  """Runs the reversible trainer and checks results match the default."""
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = 2 * inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  # We want to test rng propagation too, so adding some dropout layers.
  first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
  rev_layers1 = [
      tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
      tl.ReversibleSwap()
  ]
  mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
  rev_layers2 = [
      tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
      tl.ReversibleSwap()
  ]
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                         tl.LogSoftmax(), tl.CrossEntropyLoss())
  model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] + rev_layers2 +
                    [loss_layer])
  rng_init = fastmath.random.get_prng(12)
  model.init(labeled_batch, rng=rng_init)
  optimizer_fn = optimizers.Adam  # to test slots

  # Make 3 steps with the original trainer.
  optimizer = optimizer_fn()
  optimizer.tree_init(model.weights)
  trainer = optimizers.Trainer(model, optimizer)
  rng_step1 = fastmath.random.get_prng(7)
  rng_step2 = fastmath.random.get_prng(8)
  rng_step3 = fastmath.random.get_prng(9)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
  first_layer_weights1 = first_layer.weights
  rev_layer12_weights1 = rev_layers1[2].weights
  mid_layer_weights1 = mid_layer.weights
  rev_layer20_weights1 = rev_layers2[0].weights
  loss_layer_weights1 = loss_layer.weights

  # Now make 3 steps with the reversible trainer.
  model.init(labeled_batch, rng=rng_init)
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer.sublayers, rev_layers1),
       (mid_layer.sublayers, rev_layers2)],
      loss_layer, optimizer_fn)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

  # Check that weights end up the same.
  self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
  self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
  self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
  self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
  self._assert_all_equal(first_layer_weights1, first_layer.weights)

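# The test above relies on a comparison helper that is not shown in this
# snippet. A minimal sketch, assuming it simply walks the (possibly nested)
# weight structures and compares leaves within a tolerance; the exact
# implementation here is an assumption, not necessarily the library's own code.
def _assert_all_equal(self, t1, t2, tol=1e-5):
  def eq(x1, x2):
    diff = np.maximum(np.abs(x1 - x2) - tol, 0.0)
    self.assertLessEqual(np.sum(diff), 0.0,
                         msg=f'\n{x1}\n !=\n{x2}\n diff:\n{np.abs(x1 - x2)}')
  fastmath.nested_map_multiarg(eq, t1, t2)
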
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity, attention_chunk_size,
                 n_attention_layers=1, n_feedforward_layers=1,
                 center_layernorm=True, use_bfloat16=False, mode='train'):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # pylint: disable=g-complex-comprehension
  attention_half_residuals = [
      [tl.ReversibleHalfResidual(
          tl.LayerNorm(center=center_layernorm),
          attention_layer=ct.ApplyAttentionLayer(
              attention_type, d_model, n_heads, d_attention_key,
              d_attention_value, True, False, dropout, dropout,
              attention_chunk_size, mode),
          name='ReversibleHalfResidualDecoderAttn'),
       tl.ReversibleSwap()
      ] for _ in range(n_attention_layers)]

  feed_forwards = [
      [tl.ReversibleHalfResidual(
          ct.FeedForwardWithOptions(
              d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
              ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
              mode, use_bfloat16),
          name='ReversibleHalfResidualDecoderFF'),
       tl.ReversibleSwap()
      ] for _ in range(n_feedforward_layers)]
  # pylint: enable=g-complex-comprehension
  return attention_half_residuals + feed_forwards

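# A usage sketch (not part of the snippet above; the hyperparameter values and
# attention class are illustrative assumptions only): DecoderBlock returns a
# list of [ReversibleHalfResidual, ReversibleSwap] groups, which tl.Serial
# flattens. A decoder stack is typically built by duplicating the activations
# into the two reversible streams, applying the blocks, and merging the
# streams again afterwards.
blocks = DecoderBlock(
    d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64,
    n_heads=8, attention_type=tl.SelfAttention, dropout=0.1,
    ff_activation=tl.Relu, ff_dropout=0.1, ff_use_sru=0, ff_chunk_size=0,
    ff_sparsity=0, attention_chunk_size=0, mode='train')

decoder_stack = tl.Serial(
    tl.Dup(),          # (x,) -> (x, x): split into two reversible streams
    blocks,            # reversible half-residuals interleaved with swaps
    tl.Concatenate(),  # merge the two streams back into a single tensor
    tl.LayerNorm(),
)
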
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer encoder-decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      attention_dropout=dropout, output_dropout=dropout, mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      causal=True, attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    mode = 'eval'

  attention = attention_type(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      masked=True, causal=False,
      attention_dropout=dropout, output_dropout=dropout, mode=mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = FeedForwardWithOptions(
      d_model, d_ff, dropout, ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode):
  """Reversible transformer encoder-decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      attention_dropout=dropout, output_dropout=dropout, mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      causal=True, attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = FeedForward(d_model, d_ff, dropout, ff_activation,
                             ff_dropout, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def _attention_half_residual():
  return [
      tl.ReversibleHalfResidual(
          tl.LayerNorm(center=center_layernorm),
          attention_layer=_Attn(),
          name='ReversibleHalfResidualDecoderAttn'),
      tl.ReversibleSwap()
  ]

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_use_sru,
                 ff_chunk_size, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  attention = attention_type(
      n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
      causal=True, output_dropout=dropout, mode=mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [
        ChunkedFeedForward(d_model, d_ff, dropout, ff_activation, dropout,
                           ff_chunk_size, mode)
    ]

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def _feed_forward():
  layers = [
      tl.ReversibleHalfResidual(_FF(), name='ReversibleHalfResidualEncoderFF')
  ]
  if use_two_swaps_per_block:
    layers.append(tl.ReversibleSwap())
  return layers

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity, attention_chunk_size,
                 mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  attention = ct.ApplyAttentionLayer(attention_type, d_model, n_heads,
                                     d_attention_key, d_attention_value,
                                     True, False, dropout, dropout,
                                     attention_chunk_size, mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                           ff_activation, ff_dropout,
                                           ff_chunk_size, ff_use_sru,
                                           ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # TODO(lukaszkaiser): unify attention layers API and remove this branch
  try:
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        causal=True, output_dropout=dropout, mode=mode)
  except TypeError:  # No d_qk arguments in less advanced layers.
    attention = attention_type(d_model, n_heads=n_heads,
                               dropout=dropout, mode=mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = FeedForwardWithOptions(
      d_model, d_ff, dropout, ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def test_run_reversible_slots(self):
  """Tests that slots can be read and assigned in reversible trainer."""
  layers = [tl.Dense(4), tl.Dup()]
  rev_layers = [tl.ReversibleHalfResidual(tl.Dense(4)), tl.ReversibleSwap()]
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(4), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  trainer = optimizers.ReversibleSerialTrainer(
      [(layers, rev_layers)], loss_layer, optimizers.Adam)
  slots = trainer.slots
  trainer.slots = slots
  self.assertEqual(slots, trainer.slots)

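# A brief illustration of why the slots property matters in practice (a sketch
# under the assumption that `layers`, `rev_layers` and `loss_layer` are set up
# as in the test above; this is not the library's own checkpointing code):
# optimizer slots such as Adam's moment estimates can be captured from one
# trainer and restored into a freshly constructed one to resume optimizer
# state.
saved_slots = trainer.slots  # capture optimizer state after some steps
new_trainer = optimizers.ReversibleSerialTrainer(
    [(layers, rev_layers)], loss_layer, optimizers.Adam)
new_trainer.slots = saved_slots  # restore optimizer state
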
def test_run_reversible_large_weights(self):
  """Runs the reversible trainer with a lot of weights to test memory use."""
  # This test requires > 18GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.global_device_count() == 1 and ram_limited:
    return

  # Create inputs and rngs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(9, 16 * 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 18  # 18 layers each 16K x 16K = 256M weights ~= 1GB, 18GB ram
  rev_layers = []
  int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
  shape = shapes.ShapeDtype((2, 4, 16 * 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(16 * 1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.Adafactor

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  # Set to True to run again, e.g., for profiling.
  run_twice = False
  if run_twice:
    t = time.time()
    loss, _ = trainer.one_step(labeled_batch, rng_step)
    self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
    print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))

def test_run_reversible_weights_transfer_xprof(self):
  """Runs the reversible trainer and profiles weight transfer stats."""
  run_this_test = False  # We only run this test manually.
  if not run_this_test or fastmath.global_device_count() == 1:  # TPU only.
    return

  # Create inputs and rngs.
  inputs_batch = np.ones((1024, 128), dtype=np.int32)
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(4, 1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 6
  rev_layers = []
  int_shape = shapes.ShapeDtype((1024, 128), dtype=np.int32)
  shape = shapes.ShapeDtype((1024, 128, 1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.SGD

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer, rev_layers)], loss_layer, optimizer_fn)
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  # We profile here.
  t = time.time()
  loss, _ = trainer.one_step(labeled_batch, rng_step)
  self.assertLess(float(loss.sum()), 10000.0)  # Just to get the loss.
  print('Took %.3f seconds to run, loss %s' % (time.time() - t, loss))

def test_train_memory_efficient(self):
  """Trains a large network in a memory-efficient way."""
  # This test requires > 16GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.device_count() == 1 and ram_limited:
    return

  # Create the model.
  n_layers = 16  # 16 layers each 16K x 16K = 256M weights ~= 1GB, 16GB ram
  model = tl.Serial(
      tl.Embedding(9, 16 * 1024),
      tl.Dup(),
      [[tl.ReversibleHalfResidual(tl.Dense(16 * 1024)), tl.ReversibleSwap()]
       for _ in range(n_layers)],
      tl.Concatenate(),
      tl.Dense(9),
  )

  # Create inputs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  def _data_gen():
    while True:
      yield labeled_batch

  # Run training.
  loss_layer = tl.WeightedCategoryCrossEntropy()
  task = training.TrainTask(_data_gen(), loss_layer, optimizers.Adafactor)
  eval_task = training.EvalTask(_data_gen(),
                                [tl.WeightedCategoryCrossEntropy()])
  loop = training.Loop(model, [task],
                       eval_tasks=[eval_task],
                       eval_at=lambda step_n: step_n == 2,
                       use_memory_efficient_trainer=True)
  self.assertEqual(0, loop.step)
  loop.run(n_steps=2)
  self.assertEqual(2, loop.step)

def test_run_reversible_large_weights(self):
  """Runs the reversible trainer with a lot of weights to test memory use."""
  # This test requires > 20GB RAM, only run on TPUs. It does pass on GPU
  # and CPU when you run it locally, but it's too big for unit-testing.
  ram_limited = True  # Set to False to run this test locally.
  if fastmath.device_count() == 1 and ram_limited:
    return

  # Create inputs and rngs.
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  first_layer = tl.Serial(tl.Embedding(9, 16*1024), tl.Dup())
  rng_init = fastmath.random.get_prng(12)
  rng_step = fastmath.random.get_prng(13)

  # Initialize layers.
  first_layer.init(labeled_batch, rng=rng_init)
  n_layers = 20  # 20 layers each 16K x 16K = 256M weights ~= 1GB, 20GB ram
  rev_layers = []
  int_shape = shapes.ShapeDtype((2, 4), dtype=np.int32)
  shape = shapes.ShapeDtype((2, 4, 16*1024))
  sig = (shape, shape)
  for _ in range(n_layers):
    layer = tl.ReversibleHalfResidual(tl.Dense(16*1024))
    layer.init(sig, rng=rng_init)
    layer.weights = tl.on_cpu(layer.weights)  # store weights in cpu memory
    rev_layers.append(layer)
    rev_layers.append(tl.ReversibleSwap())
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(9), tl.LogSoftmax(),
                         tl.CrossEntropyLoss())
  loss_layer.init((shape, shape, int_shape, int_shape))
  optimizer_fn = optimizers.Adafactor

  # Make a step with reversible trainer.
  trainer = optimizers.ReversibleSerialTrainer(
      first_layer, rev_layers, loss_layer, optimizer_fn)
  trainer.one_step(labeled_batch, rng_step)

def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, attention_chunk_size=0, use_bfloat16=False,
                 mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    mode = 'eval'

  attention = ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_model // n_heads,
      d_v=d_model // n_heads,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)
  # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
  # If we're using standard attention, we need to pass reshaped mask and not
  # return the mask to be compatible with the EfficientAttention API.
  if attention.n_out == 2:
    def reshape_mask(mask):
      return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
    attention = tl.Serial(
        tl.Fn('ReshapeMask', lambda x, y: (x, reshape_mask(y)), n_out=2),
        attention,
        tl.Select([0], n_in=2))

  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2],
                                           ff_activation, ff_dropout,
                                           ff_chunk_size, ff_use_sru,
                                           ff_sparsity, mode, use_bfloat16)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

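# A small illustration of what the reshape_mask closure above does (plain
# NumPy, toy sizes; the example values are assumptions for exposition only):
# a padding mask of shape (batch, seq_len) is reshaped to
# (batch, 1, 1, seq_len) so it broadcasts against attention logits of shape
# (batch, n_heads, q_len, k_len) in standard attention.
import numpy as np

mask = np.array([[1, 1, 1, 0],   # batch of 2, seq_len of 4; 0 marks padding
                 [1, 1, 0, 0]])
reshaped = np.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
assert reshaped.shape == (2, 1, 1, 4)
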
def _feed_forward():
  return [
      tl.ReversibleHalfResidual(_FF(), name='ReversibleHalfResidualDecoderFF'),
      tl.ReversibleSwap()
  ]