def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity,
                 attention_chunk_size, n_attention_layers=1,
                 n_feedforward_layers=1, center_layernorm=True,
                 use_bfloat16=False, mode='train'):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default 1).
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  def _attention_block():
    # One reversible pre-norm causal-attention residual followed by a swap
    # of the two reversible streams.
    return [
        tl.ReversibleHalfResidual(
            tl.LayerNorm(center=center_layernorm),
            attention_layer=ct.ApplyAttentionLayer(
                attention_type, d_model, n_heads, d_attention_key,
                d_attention_value, True, False, dropout, dropout,
                attention_chunk_size, mode),
            name='ReversibleHalfResidualDecoderAttn'),
        tl.ReversibleSwap(),
    ]

  def _feedforward_block():
    # One reversible feed-forward residual (LayerNorm is applied inside
    # FeedForwardWithOptions) followed by a stream swap.
    return [
        tl.ReversibleHalfResidual(
            ct.FeedForwardWithOptions(
                d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
                ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm,
                mode, use_bfloat16),
            name='ReversibleHalfResidualDecoderFF'),
        tl.ReversibleSwap(),
    ]

  # Fresh layer instances are created per repetition so no weights are shared.
  layers = [_attention_block() for _ in range(n_attention_layers)]
  layers.extend(_feedforward_block() for _ in range(n_feedforward_layers))
  return layers
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  d_head = d_model // n_heads

  def _attn_residual(attn):
    # Pre-norm reversible residual around an attention layer.
    return tl.ReversibleHalfResidual(
        tl.LayerNorm(),
        attention_layer=attn,
    )

  causal_half_residual = _attn_residual(tl.SelfAttention(
      n_heads=n_heads, d_qk=d_head, d_v=d_head,
      causal=True,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode))
  enc_dec_half_residual = _attn_residual(tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_head, d_v=d_head,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode))
  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [                            # vec_d1 vec_d2 vec_e masks
      causal_half_residual,
      tl.ReversibleSwap(),
      enc_dec_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity,
                 attention_chunk_size, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # Causal (masked=True, causal is the 6th positional flag pair: True, False
  # here means masked attention without causality reversal of that pair in
  # the original call; arguments are passed positionally as in the original).
  causal_attention = ct.ApplyAttentionLayer(
      attention_type, d_model, n_heads, d_attention_key, d_attention_value,
      True, False, dropout, dropout, attention_chunk_size, mode)
  ff_block = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  # Standard reversible decoder stack: attention residual, swap,
  # feed-forward residual, swap.
  return [
      tl.ReversibleHalfResidual(
          tl.LayerNorm(),
          attention_layer=causal_attention,
      ),
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(ff_block),
      tl.ReversibleSwap(),
  ]
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, attention_chunk_size=0, use_bfloat16=False,
                 mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # 'predict' means running the decoder one token at a time; the encoder
    # always sees full sequences, so it runs in 'eval' mode instead.
    mode = 'eval'

  d_head = d_model // n_heads
  attention = ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_head,
      d_v=d_head,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)

  # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
  # Standard attention layers here have n_out == 2 (activations, mask); to be
  # compatible with the EfficientAttention API we reshape the mask on the way
  # in and drop the mask output on the way out.
  if attention.n_out == 2:
    attention = tl.Serial(
        tl.Fn(
            'ReshapeMask',
            lambda x, y: (x, jnp.reshape(y, (y.shape[0], 1, 1, y.shape[1]))),
            n_out=2),
        attention,
        tl.Select([0], n_in=2))

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode, use_bfloat16)

  return [
      tl.ReversibleHalfResidual(
          tl.LayerNorm(),
          attention_layer=attention,
      ),
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
def _FF():
  # Closure helper: builds one feed-forward block with the surrounding
  # function's hyperparameters.
  # NOTE(review): this is a nested function — d_model, d_ff, dropout,
  # ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
  # center_layernorm, mode and use_bfloat16 are all free variables captured
  # from an enclosing scope not visible in this chunk.
  return ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size,
      ff_use_sru, ff_sparsity, center_layernorm, mode, use_bfloat16)
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # 'predict' (one token at a time) only makes sense for decoders; the
    # encoder always runs over full sequences, so switch it to 'eval'.
    mode = 'eval'

  d_head = d_model // n_heads
  masked_attention = configurable_transformer.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_head,
      d_v=d_head,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      mode=mode)
  feed_forward = configurable_transformer.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [
      tl.ReversibleHalfResidual(
          tl.LayerNorm(),
          attention_layer=masked_attention,
      ),
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]