# Assumed imports, following trax conventions (this excerpt uses both the
# short alias `ct` and the full name `configurable_transformer`):
from trax import layers as tl
from trax.fastmath import numpy as jnp
from trax.models.research import configurable_transformer
from trax.models.research import configurable_transformer as ct


def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 attention_type, dropout, ff_activation, ff_dropout,
                 ff_use_sru, ff_chunk_size, ff_sparsity, attention_chunk_size,
                 n_attention_layers=1, n_feedforward_layers=1,
                 use_bfloat16=False, mode='train'):
  """Reversible transformer decoder layer.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_attention_layers: how many residual causal attention layers should we
      have before the feed-forward block (default: 1, the standard block)
    n_feedforward_layers: how many FFNN layers should we have (default: 1)
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # pylint: disable=g-complex-comprehension
  attention_half_residuals = [
      [tl.ReversibleHalfResidual(
          tl.LayerNorm(),
          attention_layer=ct.ApplyAttentionLayer(
              attention_type, d_model, n_heads, d_attention_key,
              d_attention_value, True, False, dropout, dropout,
              attention_chunk_size, mode)),
       tl.ReversibleSwap()
      ] for _ in range(n_attention_layers)]

  feed_forwards = [
      [tl.ReversibleHalfResidual(
          ct.FeedForwardWithOptions(
              d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
              ff_chunk_size, ff_use_sru, ff_sparsity, mode, use_bfloat16)),
       tl.ReversibleSwap()
      ] for _ in range(n_feedforward_layers)]
  # pylint: enable=g-complex-comprehension

  return attention_half_residuals + feed_forwards
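
# Usage sketch: builds one standard decoder block and splices it into a
# reversible stack, mirroring ReformerLM-style wiring. Assumes the imports
# above; tl.SelfAttention and all hyperparameter values are illustrative
# choices, not defaults from this module.
def _DecoderBlockUsageSketch():
  decoder_block = DecoderBlock(
      d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64,
      n_heads=8, attention_type=tl.SelfAttention, dropout=0.1,
      ff_activation=tl.Relu, ff_dropout=0.1, ff_use_sru=0, ff_chunk_size=0,
      ff_sparsity=0, attention_chunk_size=0, mode='train')
  # DecoderBlock returns a list of [half-residual, swap] sublayer pairs, so
  # it is flattened into a reversible serial combinator, not called directly.
  return tl.Serial(
      tl.Dup(),                            # duplicate activations: x -> (x, x)
      tl.ReversibleSerial(decoder_block),  # nested layer lists are flattened
      tl.Concatenate(),                    # merge the two halves back together
      tl.LayerNorm(),
  )
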
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, attention_chunk_size=0, center_layernorm=True,
                 use_bfloat16=False, use_two_swaps_per_block=True,
                 mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    attention_chunk_size: int, if > 0 run attention chunked at this size
    center_layernorm: whether to use centering in LayerNorm (the default) or
      to skip it, which is known as RMS normalization
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    use_two_swaps_per_block: bool, if True use two reversible swaps in the
      encoder block, otherwise use only one swap
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    mode = 'eval'

  attention = ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_model // n_heads,
      d_v=d_model // n_heads,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)

  # TODO(lukaszkaiser): refactor efficient attention layers to unify the API.
  # If we're using standard attention, we need to pass a reshaped mask and
  # drop it from the output to be compatible with the EfficientAttention API.
  if attention.n_out == 2:
    def reshape_mask(mask):
      return jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
    attention = tl.Serial(
        tl.Fn('ReshapeMask', lambda x, y: (x, reshape_mask(y)), n_out=2),
        attention,
        tl.Select([0], n_in=2))

  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(center=center_layernorm),
      attention_layer=attention,
      name='ReversibleHalfResidualEncoderAttn')

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size,
      ff_use_sru, ff_sparsity, center_layernorm, mode, use_bfloat16)

  encoder_block = [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward,
                                name='ReversibleHalfResidualEncoderFF'),
  ]
  if use_two_swaps_per_block:
    encoder_block.append(tl.ReversibleSwap())
  return encoder_block
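
# Usage sketch for the configurable variant above: with center_layernorm=False
# the normalization becomes RMS-style, and with use_two_swaps_per_block=False
# the block emits only one ReversibleSwap. tl.Attention and the hyperparameter
# values here are illustrative assumptions, not defaults from this module.
def _EncoderBlockUsageSketch():
  return EncoderBlock(
      d_model=512, d_ff=2048, n_heads=8, attention_type=tl.Attention,
      dropout=0.1, ff_activation=tl.Relu, ff_dropout=0.1,
      center_layernorm=False,         # RMS normalization instead of LayerNorm
      use_two_swaps_per_block=False,  # single ReversibleSwap in this block
      mode='train')
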
# A simpler variant of EncoderBlock, without attention chunking, the LayerNorm
# options, or the mask reshaping for standard attention.
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout,
                 ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0,
                 ff_sparsity=0, mode='train'):
  """Returns a list of layers that implements a Reformer encoder block.

  The input to the layer is a pair, (activations, mask), where the mask was
  created from the original source tokens to prevent attending to the padding
  part of the input.

  Args:
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: the dropout rate in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train' or 'eval'

  Returns:
    A list of layers that maps (activations, mask) to (activations, mask).
  """
  if mode == 'predict':
    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    mode = 'eval'

  attention = configurable_transformer.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_model // n_heads,
      d_v=d_model // n_heads,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      mode=mode)

  attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=attention,
  )

  feed_forward = configurable_transformer.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
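
# Usage sketch: stacks several of these encoder blocks into one reversible
# serial combinator, the way the Reformer models consume such block lists.
# tl.Attention as the masked, non-causal attention type and the depths below
# are illustrative assumptions.
def _EncoderStackUsageSketch(n_layers=6):
  encoder_blocks = [
      EncoderBlock(d_model=256, d_ff=1024, n_heads=4,
                   attention_type=tl.Attention, dropout=0.1,
                   ff_activation=tl.Relu, ff_dropout=0.1)
      for _ in range(n_layers)
  ]
  # The stream is (x1, x2, mask); the mask passes through each block unchanged.
  return tl.Serial(
      tl.Dup(),                             # (x, mask) -> (x, x, mask)
      tl.ReversibleSerial(encoder_blocks),  # flattens the list of block lists
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # merge the two halves
      tl.LayerNorm(),
  )
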
# _Attn helper passing positional arguments to ct.ApplyAttentionLayer, as in
# the comprehension inside DecoderBlock above.
def _Attn():
  return ct.ApplyAttentionLayer(
      attention_type, d_model, n_heads, d_attention_key, d_attention_value,
      True, False, dropout, dropout, attention_chunk_size, mode)
# Keyword-argument variant of the _Attn helper: per-head depths are derived
# as d_model // n_heads, and the attention is masked and non-causal
# (encoder-side) rather than taking the depths as parameters.
def _Attn():
  return ct.ApplyAttentionLayer(
      attention_type=attention_type,
      d_model=d_model,
      n_heads=n_heads,
      d_qk=d_model // n_heads,
      d_v=d_model // n_heads,
      masked=True,
      causal=False,
      attention_dropout=dropout,
      output_dropout=dropout,
      attention_chunk_size=attention_chunk_size,
      mode=mode)
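
# Sketch of how these _Attn helpers are consumed: they are closures inside a
# block builder (see DecoderBlock/EncoderBlock above), called once and wrapped
# in a reversible half-residual around a LayerNorm. _BlockSketch and its
# parameter list are hypothetical scaffolding for illustration.
def _BlockSketch(d_model, n_heads, attention_type, dropout,
                 attention_chunk_size, mode):
  def _Attn():
    return ct.ApplyAttentionLayer(
        attention_type=attention_type, d_model=d_model, n_heads=n_heads,
        d_qk=d_model // n_heads, d_v=d_model // n_heads,
        masked=True, causal=False,
        attention_dropout=dropout, output_dropout=dropout,
        attention_chunk_size=attention_chunk_size, mode=mode)
  return [tl.ReversibleHalfResidual(tl.LayerNorm(), attention_layer=_Attn()),
          tl.ReversibleSwap()]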