def PositionalEncoding(mode, dropout=None, max_len=None,
                       axial_pos_shape=None, d_axial_pos_embs=None):
  """Returns the positional encoding layer depending on the arguments."""
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)
  return positional_encoding
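# Illustrative usage sketch (not part of the original module): `axial_pos_shape`
# is overloaded above -- a falsy value selects the default table encoding, one
# of the sentinel strings selects a "HACK" variant, and a tuple of ints selects
# the axial encoding. The hyperparameter values below are made up for the
# example and assume d_model == 512 with sequences of at most 4096 tokens.
def _positional_encoding_examples(mode='train'):
  enc_table = PositionalEncoding(mode, dropout=0.1, max_len=2048)
  enc_timebin = PositionalEncoding(mode, axial_pos_shape='time-bin')
  enc_axial = PositionalEncoding(
      mode, dropout=0.1,
      axial_pos_shape=(64, 64),      # 64 * 64 = 4096 grid of positions
      d_axial_pos_embs=(128, 384))   # 128 + 384 = 512 == d_model
  return enc_table, enc_timebin, enc_axial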
def PositionalEncoder(mode, dropout=None, max_len=None, pos_type=None,
                      pos_axial_shape=None, pos_d_axial_embs=None,
                      pos_start_from_zero_prob=1.0, pos_max_offset_to_add=0,
                      use_bfloat16=False):
  """Returns the positional encoding layer depending on the arguments.

  Args:
    mode: If `'predict'`, use fast inference. If `'train'`, each
      encoder/decoder block will include dropout; else, it will pass all
      values through unaltered.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the embedding block.
    max_len: Maximum symbol length for positional encoding.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match pos_axial_shape, and values must sum to
      d_model.
    pos_start_from_zero_prob: how often to start from position 0 during
      training (if 1.0, we always start from position 0; if less, we
      randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
      when randomizing; this offset plus input length must still be less than
      max_len for all training examples.
    use_bfloat16: If `True`, use bfloat16 weights instead of the default
      float32; this can save memory but may (rarely) lead to numerical issues.

  Returns:
    A layer that will do the positional encoding.
  """
  if not pos_type:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, use_bfloat16=use_bfloat16,
        start_from_zero_prob=pos_start_from_zero_prob,
        max_offset_to_add=pos_max_offset_to_add, mode=mode)
  elif pos_type == 'sin-cos':
    positional_encoding = tl.SinCosPositionalEncoding(mode=mode)
  elif pos_type == 'fixed-base':
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
  elif pos_type == 'infinite':
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif pos_type == 'infinite-affine':
    positional_encoding = tl.InfinitePositionalEncoding()
  elif pos_type == 'time-bin':
    positional_encoding = tl.TimeBinPositionalEncoding()
  elif pos_type == 'no':
    positional_encoding = tl.Serial()  # no positional encoding at all
  else:  # TODO(lukaszkaiser): name this type and check for the correct name
    assert pos_d_axial_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=pos_axial_shape, d_embs=pos_d_axial_embs,
        dropout_broadcast_dims=tuple(range(1, len(pos_axial_shape) + 1)),
        dropout=dropout, mode=mode)
  return positional_encoding
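# Illustrative configuration sketch (not part of the original module): using
# randomized position offsets as described in the docstring above. With
# max_len=16384 and training inputs of at most 8192 tokens, an offset of at
# most 4096 keeps offset + input length below max_len. All values here are
# made up for the example, not recommendations.
def _randomized_offset_encoder_example():
  return PositionalEncoder(
      mode='train',
      dropout=0.1,
      max_len=16384,
      pos_start_from_zero_prob=0.5,   # start at position 0 half the time
      pos_max_offset_to_add=4096)     # otherwise add a random offset <= 4096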
def PositionalEncoder(mode, dropout=None, max_len=None,
                      axial_pos_shape=None, d_axial_pos_embs=None,
                      use_bfloat16=False):
  """Returns the positional encoding layer depending on the arguments.

  Args:
    mode: If `'predict'`, use fast inference. If `'train'`, each
      encoder/decoder block will include dropout; else, it will pass all
      values through unaltered.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the embedding block.
    max_len: Maximum symbol length for positional encoding.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    use_bfloat16: If `True`, use bfloat16 weights instead of the default
      float32; this can save memory but may (rarely) lead to numerical issues.

  Returns:
    A layer that will do the positional encoding.
  """
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode,
        use_bfloat16=use_bfloat16)
  elif axial_pos_shape == 'sin-cos':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.SinCosPositionalEncoding(mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)
  return positional_encoding
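# Hedged sanity-check sketch (the helper below is illustrative, not part of the
# library): per the docstring above, the per-axis depths must sum to d_model;
# the product of the axial shape is additionally assumed here to cover the
# longest input so that positions can be factored over the axes.
def _check_axial_config(axial_pos_shape, d_axial_pos_embs, d_model, seq_len):
  n_positions = 1
  for dim in axial_pos_shape:
    n_positions *= dim
  assert n_positions >= seq_len, 'axial shape should cover the sequence length'
  assert sum(d_axial_pos_embs) == d_model, 'per-axis depths must sum to d_model'


# Example: 128 * 512 = 65536 positions, 64 + 448 = 512 embedding channels.
# _check_axial_config((128, 512), (64, 448), d_model=512, seq_len=65536)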
def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
  # Nested helper: `axial_pos_shape`, `d_axial_pos_embs`, `max_len`, `dropout`
  # and `d_model` are free variables here, so this variant only works when
  # defined inside a model-building function that provides them.
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)
  return [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
      positional_encoding,
  ]
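# Illustrative sketch (plain numpy, not library code) of what a dropout with
# shared_axes=[-2] means for the embedding block above: one mask per
# (batch, feature) pair is broadcast along the length axis, so a dropped
# feature channel is dropped consistently across the whole sequence.
def _broadcasted_dropout_sketch(rate=0.1):
  import numpy as onp
  x = onp.ones((2, 5, 8))                        # [batch, length, d_model]
  mask = onp.random.rand(2, 1, 8) > rate         # length axis (-2) has size 1
  return onp.where(mask, x / (1.0 - rate), 0.0)  # inverted-dropout scaling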
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and the input has shape [batch, length], we embed
  the (shifted-right) input and then group every F consecutive positions (on
  the length axis) into a single vector, so the main body of the model
  processes a tensor of shape [batch, length // F, d_model]. At the end the
  sequence is un-shortened and an SRU is applied. This reduces the length
  processed inside the main model body, making the model faster but possibly
  slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, values must sum to
      d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),  # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by the shorten factor so as not to
      # leak information into the future. To understand why, imagine a shorten
      # factor of 2 and a sequence of length 4, ABCD. If we shift just by 1,
      # we get 0ABC, which is grouped to [0A][BC] on input while predicting
      # ABCD as targets. The problem is that [0A] has access to A and [BC] has
      # access to C -- the model will learn to copy them, i.e. peek into the
      # future. Shifting twice to [00][AB] solves the problem, as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               mode='train'):
  """Reversible transformer language model (only uses a decoder, no encoder).

  Args:
    vocab_size: int: vocab size
    d_model: int: depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_chunks: int: number of chunks (must match input pipeline)
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as
      DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
  if n_chunks == 0:
    n_chunks = 1
    concatenate_input_chunks = []
  else:
    concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)

  d_emb = d_model
  if not axial_pos_shape:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
  elif axial_pos_shape == 'fixed-base':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.FixedBasePositionalEncoding(mode=mode)
    d_emb //= 2
  elif axial_pos_shape == 'infinite':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding(affine=False)
  elif axial_pos_shape == 'infinite-affine':
    # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.InfinitePositionalEncoding()
  elif axial_pos_shape == 'time-bin':  # TODO(lukaszkaiser): remove this HACK
    positional_encoding = tl.TimeBinPositionalEncoding()
  else:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_emb, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  decoder_blocks = []

  if isinstance(attention_type, (tuple, list)):
    assert n_layers % len(attention_type) == 0
  else:
    attention_type = [attention_type]
  for layer_idx in range(n_layers):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  return tl.Serial(
      concatenate_input_chunks,
      tl.ShiftRight(mode=mode),
      positional_embedder,
      tl.Dup(),
      tl.ReversibleSerial(decoder_blocks + [
          SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
      ]),
      Map([
          # TODO(kitaev): Test whether dropout should go before or after the
          # LayerNorm, and whether dropout broadcasting is needed here.
          tl.LayerNorm(),
          BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
          tl.Dense(vocab_size),
          tl.LogSoftmax(),
      ], n_sections=n_chunks),
  )
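# Minimal construction sketch (not part of the original module): a small
# ReformerLM built with the signature above. Hyperparameter values are made up
# for illustration, not recommendations.
def _tiny_reformer_lm(mode='train'):
  return ReformerLM(
      vocab_size=320,
      d_model=256,
      d_ff=512,
      n_layers=2,
      n_heads=4,
      max_len=2048,
      attention_type=tl.DotProductCausalAttention,
      mode=mode)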