from trax import layers as tl
# `ct` aliases trax's configurable-transformer utilities, which provide
# FeedForwardWithOptions.
from trax.models.research import configurable_transformer as ct


def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int; if > 0, use a sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  # Attention over the encoder output, wrapped in a reversible half-residual.
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      attention_dropout=dropout, output_dropout=dropout, mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  # Masked self-attention over the decoder stream.
  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      causal=True, attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
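# Usage sketch (illustrative; not part of the original module): the function
# above returns a *list* of reversible layers, so a decoder stack is built by
# concatenating several such lists and wrapping them in tl.ReversibleSerial.
# The hyperparameter values and the helper name `_DecoderStack` are
# assumptions for this example; the full Reformer model also adds embedding
# and residual-stream plumbing around the stack.
def _DecoderStack(n_decoder_layers=2, d_model=512, d_ff=2048, n_heads=8,
                  dropout=0.1, mode='train'):
  blocks = []
  for _ in range(n_decoder_layers):
    blocks.extend(EncoderDecoderBlock(
        d_model, d_ff, n_heads, dropout,
        ff_activation=tl.Relu, ff_dropout=dropout, mode=mode))
  # ReversibleSerial runs the blocks over the two reversible streams
  # (vec_d1, vec_d2), recomputing activations in the backward pass instead
  # of storing them.
  return tl.ReversibleSerial(blocks)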
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      attention_dropout=dropout, output_dropout=dropout, mode=mode)
  enc_dec_attention_half_residual = ReversibleHalfResidualV2(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads,
      causal=True, attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = ReversibleHalfResidualV2(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = FeedForward(
      d_model, d_ff, dropout, ff_activation, ff_dropout, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      ReversibleHalfResidualV2(feed_forward),
      tl.ReversibleSwap(),
  ]
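# Note: the definition above is an earlier variant of the same block. It
# depends on two module-local helpers that are not shown here:
# ReversibleHalfResidualV2 (apparently superseded by tl.ReversibleHalfResidual
# in the first definition) and FeedForward. Below is a minimal sketch of what
# FeedForward plausibly looks like, assuming trax's standard dense
# feed-forward pattern; the exact original may differ.
def FeedForward(d_model, d_ff, dropout, activation, act_dropout, mode):
  """LayerNorm -> Dense(d_ff) -> activation -> Dense(d_model), with dropout."""
  if act_dropout is None:
    act_dropout = dropout  # fall back to the block-level dropout rate
  return [
      tl.LayerNorm(),
      tl.Dense(d_ff),
      tl.Dropout(rate=act_dropout, shared_axes=[-2], mode=mode),
      activation(),  # `activation` is a layer constructor, e.g. tl.Relu
      tl.Dense(d_model),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
  ]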