def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk,
                 ff_activation, ff_use_sru, ff_chunk_size, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer, as a list of reversible sublayers.
  """
  if not hasattr(attention_type, 'forward_unbatched'):
    # Old-style attention: build query/key/value heads explicitly around the
    # attention layer.
    if share_qk:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_value),
          ),
          tl.Dup(),
      ]
    else:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(),
          tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_value),
          ),
      ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    attention_half_residual = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
  else:
    # New-style attention (with forward_unbatched) builds its own heads and
    # handles causal masking and query/key sharing internally.
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        share_qk=share_qk, causal=True, output_dropout=dropout, mode=mode)
    attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=attention,
    )

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                                       dropout, ff_chunk_size, mode)]

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
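
# Conceptual sketch (added for illustration; plain Python, not trax layers) of
# the RevNet-style computation this block assembles. With f = attention and
# g = feed-forward, the half-residual/ReversibleSwap pairs realize the Reformer
# update y1 = x1 + f(x2), y2 = x2 + g(y1); the inputs can be recomputed from
# the outputs, so activations need not be stored for the backward pass. The
# helper names below are hypothetical and assume f and g are deterministic.
def _reversible_block_forward(x1, x2, f, g):
  """Forward pass of one reversible block on a pair of activation streams."""
  y1 = x1 + f(x2)  # attention half-residual (then swap streams)
  y2 = x2 + g(y1)  # feed-forward half-residual (then swap streams)
  return y1, y2


def _reversible_block_reconstruct(y1, y2, f, g):
  """Recompute the block's inputs from its outputs (used during backprop)."""
  x2 = y2 - g(y1)
  x1 = y1 - f(x2)
  return x1, x2
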
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, share_qk, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    mode: str: 'train' or 'eval'

  Returns:
    the layer, as a list of reversible sublayers.
  """
  if share_qk:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
        tl.Dup(),
    ]
  else:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
    ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_model, d_ff, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
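
# Usage sketch (added for illustration, not part of the original source):
# assembling a small stack of these reversible decoder blocks. The helper name
# `_example_decoder_stack`, the hyperparameter values, and the availability of
# tl.DotProductCausalAttention and tl.ReversibleSerial in this trax version
# are assumptions; substitute whatever attention class and settings your
# configuration uses.
def _example_decoder_stack(mode='train'):
  blocks = []
  for _ in range(2):  # two decoder blocks, each a list of reversible sublayers
    blocks += DecoderBlock(
        d_model=512, d_ff=2048,
        d_attention_key=64, d_attention_value=64, n_heads=8,
        n_attention_chunks=1,
        attention_type=tl.DotProductCausalAttention,
        dropout=0.1, share_qk=False, mode=mode)
  # Reversible layers run on a pair of activation streams; the surrounding
  # model is expected to tl.Dup() the input before this stack and merge the
  # two streams afterwards (as ReformerLM does).
  return tl.ReversibleSerial(blocks)
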