def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
                 attention_chunk_size=0, center_layernorm=True,
                 use_bfloat16=False, use_two_swaps_per_block=True,
                 mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder
      block, otherwise use only one swap.
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    mode = 'eval'

  attention = ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_model // n_heads,
      d_v=d_model // n_heads,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)

  # TODO(lukaszkaiser): refactor efficient attention layers to unify the API.
  # If we're using standard attention, we need to pass reshaped mask and not
  # return the mask to be compatible with the EfficientAttention API.
  if attention.n_out == 2:
    def reshape_mask(mask):
      return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
    attention = tl.Serial(
        tl.Fn('ReshapeMask', lambda x, y: (x, reshape_mask(y)), n_out=2),
        attention,
        tl.Select([0], n_in=2))

  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(center=center_layernorm),
      attention_layer=attention,
      name='ReversibleHalfResidualEncoderAttn')

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size,
      ff_use_sru, ff_sparsity, center_layernorm, mode, use_bfloat16)

  encoder_block = [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward,
                                name='ReversibleHalfResidualEncoderFF'),
  ]
  if use_two_swaps_per_block:
    encoder_block.append(tl.ReversibleSwap())
  return encoder_block

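# A minimal sketch of how the list returned by EncoderBlock is typically used:
# the reversible blocks operate on two copies of the activations, so the stack
# is duplicated before tl.ReversibleSerial and the two streams are merged again
# afterwards. The helper name `_EncoderStack`, the hyperparameter values, and
# the choice of tl.SelfAttention as `attention_type` are illustrative
# assumptions, not part of the library API.
from trax import layers as tl


def _EncoderStack(d_model=512, d_ff=2048, n_heads=8, n_layers=2, mode='train'):
  blocks = [
      EncoderBlock(d_model, d_ff, n_heads,
                   attention_type=tl.SelfAttention,  # assumed attention class
                   dropout=0.1, ff_activation=tl.Relu, ff_dropout=0.1,
                   mode=mode)
      for _ in range(n_layers)
  ]
  return tl.Serial(
      tl.Dup(),                           # (x, mask) -> (x, x, mask)
      tl.ReversibleSerial(blocks),        # reversible attention/FF blocks
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # merge the two streams
      tl.LayerNorm(),
  )
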
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      attention_dropout=dropout, output_dropout=dropout, mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      causal=True, attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size,
      ff_use_sru, ff_sparsity, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

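# A minimal sketch of wiring EncoderDecoderBlock layers into the decoder half
# of an encoder-decoder Reformer. The stack carries two decoder streams plus
# the encoder activations and mask, matching the "vec_d1 vec_d2 vec_e masks"
# comment above; the surrounding layers and hyperparameter values here are
# illustrative assumptions.
from trax import layers as tl

blocks = [
    EncoderDecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                        ff_activation=tl.Relu, ff_dropout=0.1, mode='train')
    for _ in range(2)
]
decoder = tl.Serial(
    # Input stack: (vec_d, vec_e, masks).
    tl.Dup(),                                    # -> (vec_d1, vec_d2, vec_e, masks)
    tl.ReversibleSerial(blocks),                 # causal attn, enc-dec attn, FF
    tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # merge the two decoder streams
    tl.LayerNorm(),
)
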
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout, ff_use_sru,
                 ff_chunk_size, ff_sparsity, attention_chunk_size,
                 n_attention_layers=1, n_feedforward_layers=1,
                 center_layernorm=True, use_bfloat16=False, mode='train'):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or to
      skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # pylint: disable=g-complex-comprehension
  attention_half_residuals = [
      [
          tl.ReversibleHalfResidual(
              tl.LayerNorm(center=center_layernorm),
              attention_layer=ct.ApplyAttentionLayer(
                  attention_type, d_model, n_heads, d_attention_key,
                  d_attention_value, True, False, dropout, dropout,
                  attention_chunk_size, mode),
              name='ReversibleHalfResidualDecoderAttn'),
          tl.ReversibleSwap()
      ] for _ in range(n_attention_layers)
  ]

  feed_forwards = [
      [
          tl.ReversibleHalfResidual(
              ct.FeedForwardWithOptions(
                  d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
                  ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
                  mode, use_bfloat16),
              name='ReversibleHalfResidualDecoderFF'),
          tl.ReversibleSwap()
      ] for _ in range(n_feedforward_layers)
  ]
  # pylint: enable=g-complex-comprehension
  return attention_half_residuals + feed_forwards

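# A minimal sketch of a ReformerLM-style decoder built from DecoderBlock,
# following the usual Dup / ReversibleSerial / Concatenate pattern. The vocab
# size, hyperparameter values, and the choice of tl.LSHSelfAttention as
# `attention_type` are illustrative assumptions.
from trax import layers as tl

vocab_size = 32000
decoder_blocks = [
    DecoderBlock(d_model=512, d_ff=2048, d_attention_key=64,
                 d_attention_value=64, n_heads=8,
                 attention_type=tl.LSHSelfAttention,  # assumed attention class
                 dropout=0.1, ff_activation=tl.Relu, ff_dropout=0.1,
                 ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
                 attention_chunk_size=0, mode='train')
    for _ in range(2)
]
lm = tl.Serial(
    tl.ShiftRight(),                  # teacher-forcing shift of the targets
    tl.Embedding(vocab_size, 512),
    tl.Dup(),                         # two activation streams for reversibility
    tl.ReversibleSerial(decoder_blocks),
    tl.Concatenate(),                 # merge the two streams (2 * d_model)
    tl.LayerNorm(),
    tl.Dense(vocab_size),
    tl.LogSoftmax(),
)
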
def test_reversible_swap(self):
  layer = tl.ReversibleSwap()
  xs = [np.array([1, 2]), np.array([10, 20])]
  ys = layer(xs)
  self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])

def test_run_reversible_same_as_default_extended(self):
  """Runs the reversible trainer and checks results match the default."""
  inputs_batch = np.arange(8).reshape((2, 4))
  targets_batch = 2 * inputs_batch
  labeled_batch = (inputs_batch, targets_batch, np.ones_like(targets_batch))
  # We want to test rng propagation too, so adding some dropout layers.
  first_layer = tl.Serial(tl.Embedding(9, 4), tl.Dropout(0.5), tl.Dup())
  rev_layers1 = [
      tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.2)),
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(tl.Dropout(0.5), tl.Dense(4)),
      tl.ReversibleSwap()
  ]
  mid_layer = tl.Serial(tl.Add(), tl.Dense(4), tl.Dup())
  rev_layers2 = [
      tl.ReversibleHalfResidual(tl.Dense(4), tl.Dropout(0.3)),
      tl.ReversibleSwap()
  ]
  loss_layer = tl.Serial(tl.Concatenate(), tl.Dense(19), tl.Dropout(0.3),
                         tl.LogSoftmax(), tl.CrossEntropyLoss())
  model = tl.Serial([first_layer] + rev_layers1 + [mid_layer] + rev_layers2 +
                    [loss_layer])
  rng_init = fastmath.random.get_prng(12)
  model.init(labeled_batch, rng=rng_init)
  optimizer_fn = optimizers.Adam  # to test slots

  # Make 3 steps with the original trainer.
  optimizer = optimizer_fn()
  optimizer.tree_init(model.weights)
  trainer = optimizers.Trainer(model, optimizer)
  rng_step1 = fastmath.random.get_prng(7)
  rng_step2 = fastmath.random.get_prng(8)
  rng_step3 = fastmath.random.get_prng(9)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)
  first_layer_weights1 = first_layer.weights
  rev_layer12_weights1 = rev_layers1[2].weights
  mid_layer_weights1 = mid_layer.weights
  rev_layer20_weights1 = rev_layers2[0].weights
  loss_layer_weights1 = loss_layer.weights

  # Now make 3 steps with reversible trainer.
  model.init(labeled_batch, rng=rng_init)
  # TODO(lukaszkaiser): this test seems to fail with memoize_jit, why?
  trainer = optimizers.ReversibleSerialTrainer(
      [(first_layer.sublayers, rev_layers1),
       (mid_layer.sublayers, rev_layers2)],
      loss_layer, optimizer_fn, memoize_jit=False)
  trainer.one_step(labeled_batch, rng_step1)
  trainer.one_step(labeled_batch, rng_step2, learning_rate=0.02)
  trainer.one_step(labeled_batch, rng_step3, learning_rate=0.03)

  # Check that weights end up the same.
  self._assert_all_equal(loss_layer_weights1, loss_layer.weights)
  self._assert_all_equal(rev_layer20_weights1, rev_layers2[0].weights)
  self._assert_all_equal(mid_layer_weights1, mid_layer.weights)
  self._assert_all_equal(rev_layer12_weights1, rev_layers1[2].weights)
  self._assert_all_equal(first_layer_weights1, first_layer.weights)

def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation,
                 ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0,
                 attention_chunk_size=0, mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    mode = 'eval'

  attention = ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_model // n_heads,
      d_v=d_model // n_heads,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)
  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size,
      ff_use_sru, ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk,
                 ff_activation, ff_use_sru, ff_chunk_size, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if not hasattr(attention_type, 'forward_unbatched'):
    if share_qk:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_value),
          ),
          tl.Dup(),
      ]
    else:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(), tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads,
                                       d_head=d_attention_value),
          ),
      ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the
    # input).
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    attention_half_residual = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
  else:
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        share_qk=share_qk, causal=True, output_dropout=dropout, mode=mode)
    attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=attention,
    )

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                                       dropout, ff_chunk_size, mode)]

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

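# A small numpy sketch (not trax code) of the arithmetic that a reversible
# half-residual plus swap relies on: with y1 = x2 and y2 = x1 + F(x2), the
# inputs can be recomputed from the outputs, so activations need not be stored
# for the backward pass. `F` is an arbitrary stand-in for the attention or
# feed-forward branch used in the blocks above.
import numpy as np


def F(x):
  return np.tanh(x)            # stand-in for the residual branch


x1, x2 = np.array([1.0, 2.0]), np.array([3.0, 4.0])

# Forward: the half-residual updates one stream, then the swap exchanges them.
y1, y2 = x1 + F(x2), x2        # ReversibleHalfResidual
y1, y2 = y2, y1                # ReversibleSwap

# Inverse: undo the swap, then subtract the recomputed branch output.
z1, z2 = y2, y1                # undo swap
z1, z2 = z1 - F(z2), z2        # recover the original inputs

assert np.allclose(z1, x1) and np.allclose(z2, x2)
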
def test_reversible_swap(self, backend):
  with fastmath.use_backend(backend):
    layer = tl.ReversibleSwap()
    xs = [np.array([1, 2]), np.array([10, 20])]
    ys = layer(xs)
    self.assertEqual(tl.to_list(ys), [[10, 20], [1, 2]])

def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk,
                 ff_activation, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if share_qk:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads,
                                     d_head=d_attention_value),
        ),
        tl.Dup(),
    ]
  else:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(), tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads,
                                     d_head=d_attention_value),
        ),
    ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input).
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_model, d_ff, dropout, ff_activation, mode=mode),
  ]

  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]

def _feed_forward():
  return [
      tl.ReversibleHalfResidual(_FF(),
                                name='ReversibleHalfResidualDecoderFF'),
      tl.ReversibleSwap()
  ]