def test_causal_conv(self):
  """CausalConv preserves (batch, length) and maps depth 20 -> `filters`."""
  batch_size, length, d_input, n_filters = 9, 5, 20, 30
  conv = tl.CausalConv(filters=n_filters, kernel_width=3)
  inputs = np.ones((batch_size, length, d_input))
  conv.init(shapes.signature(inputs))
  outputs = conv(inputs)
  self.assertEqual(outputs.shape, (batch_size, length, n_filters))
def test_causal_conv_use_bias_false(self):
  """Without a bias, CausalConv's weights are a single kernel array."""
  batch_size, length, d_input = 9, 5, 20
  n_filters, kernel_width = 30, 3
  conv = tl.CausalConv(
      filters=n_filters, kernel_width=kernel_width, use_bias=False)
  inputs = np.ones((batch_size, length, d_input))
  conv.init(shapes.signature(inputs))
  outputs = conv(inputs)
  self.assertEqual(outputs.shape, (batch_size, length, n_filters))
  # Kernel layout checked here: (kernel_width, d_input, filters).
  self.assertEqual(conv.weights.shape, (kernel_width, d_input, n_filters))
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible Transformer language model whose body runs on a shortened sequence.

  For an input of shape [batch, length] the tokens are shifted right and
  embedded. Every `shorten_factor` consecutive positions are then merged into a
  single vector by reshaping the length axis into the depth axis. As a result,
  the reversible decoder stack processes a tensor of shape
  [batch, length // shorten_factor, d_model]. Afterwards the short
  representation is projected to `shorten_factor * d_embedding` and reshaped
  back to full length. It is then concatenated with the un-shortened
  embeddings and passed through a causal convolution, a ReLU and one SRU layer
  before the output projection. Shortening reduces the length the main body
  has to process. This makes the model faster, possibly at a small cost in
  accuracy. The input length is presumably expected to be a multiple of
  `shorten_factor` (not checked here).

  Args:
    vocab_size: int: vocab size.
    shorten_factor: int: how many consecutive positions are merged into one
      vector before the decoder stack (see above).
    d_embedding: int: depth of the embedding layer, and of the layers right
      before the output projection (CausalConv and SRU).
    d_model: int: depth of *each half* of the two-part features.
    d_ff: int: depth of feed-forward layer.
    d_attention_key: int: depth of key vector for each attention head.
    d_attention_value: int: depth of value vector for each attention head.
    n_layers: int: number of decoder layers.
    n_heads: int: number of attention heads.
    dropout: float: dropout rate (how much to drop out).
    max_len: int: maximum symbol length for positional encoding.
    n_attention_chunks: int: number of chunks for attention.
    attention_type: class, or tuple/list of classes: attention class(es) to
      use, such as DotProductAttention. A tuple/list is cycled through layer by
      layer, and its length must divide `n_layers`.
    share_qk: bool: whether to share queries and keys. Sharing is always
      enabled for layers whose attention type subclasses tl.LSHCausalAttention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, values must sum to
      d_embedding.
    ff_activation: the non-linearity in feed-forward layer.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
    mode: str: 'train' or 'eval' ('predict' is rejected).

  Returns:
    A `tl.Serial` layer whose last sublayer is `tl.LogSoftmax`.

  Raises:
    AssertionError: if mode == 'predict'; if `axial_pos_shape` is set without
      `d_axial_pos_embs`; or if a tuple/list `attention_type` has a length
      that does not divide `n_layers`.
  """
  # TODO(lukaszkaiser,kitaev): fast inference ('predict' mode).
  assert mode != 'predict'

  if axial_pos_shape:
    assert d_axial_pos_embs is not None
    n_axes = len(axial_pos_shape)
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape,
        d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, n_axes + 1)),
        dropout=dropout,
        mode=mode)
  else:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)

  embedder = [
      # NOTE(review): tl.Embedding is called here as (d_embedding, vocab_size)
      # while FunnelTransformerLM / RelformerLM call it as
      # (vocab_size, d_model); confirm which signature this function targets.
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  if isinstance(attention_type, (tuple, list)):
    # Attention types are cycled per layer, so they must tile n_layers evenly.
    assert n_layers % len(attention_type) == 0
    attention_types = attention_type
  else:
    attention_types = [attention_type]

  def make_decoder_block(layer_attention):
    # Query/key sharing is forced on for (subclasses of) LSHCausalAttention.
    qk_shared = share_qk or issubclass(layer_attention, tl.LSHCausalAttention)
    return DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention,
        dropout=dropout,
        share_qk=qk_shared,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)

  decoder_blocks = [
      make_decoder_block(attention_types[idx % len(attention_types)])
      for idx in range(n_layers)
  ]

  # pylint: disable=g-long-lambda
  prefix = [
      tl.ShiftRight(),
      embedder,
      tl.Dup(),  # Stack: (x, x); the top copy gets shortened.
  ]
  shorten_stage = [
      # Pad by shorten_factor - 1 extra positions before grouping so that no
      # group sees a future token. Example: shorten_factor 2, sequence ABCD.
      # Shifting once gives 0ABC -> groups [0A][BC] predicting ABCD, so [0A]
      # could copy A and [BC] could copy C. Shifting twice gives [00][AB]:
      # the first group is all zeros and the rest is shifted far enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
  ]
  body = [
      tl.Dense(d_model),
      tl.Dup(),  # Stack: (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),  # Keep one of the two reversible outputs.
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]
  prolong_stage = [
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
  ]
  head = [
      tl.Concatenate(),  # Join with the un-shortened embeddings x.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ]
  return tl.Serial(prefix, shorten_stage, body, prolong_stage, head)
def FunnelTransformerLM(vocab_size,
                        d_model=512,
                        d_ff=2048,
                        vanilla_layers=(0, 1),
                        shorten_factors=(3,),
                        n_funnel_blocks=(6,),
                        n_heads=8,
                        dropout=0.1,
                        dropout_shared_axes=None,
                        mode='train',
                        ff_activation=tl.FastGelu):
  """Returns a Funnel-Transformer language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor of unnormalized per-token scores (logits) for each
      sequence position over possible token IDs; shape is (batch_size,
      sequence_length, `vocab_size`). The last layer is a plain
      `tl.Dense(vocab_size)`; no LogSoftmax is applied.

  This model uses only the decoder part of the overall Transformer. The
  pipeline is:
    1. The `vanilla_layers[0]` relative decoder blocks run at full resolution.
    2. The activations are duplicated. One copy is shifted right by
       (total pooling - 1) positions, shortened stage by stage (one downsampling
       block per entry of `shorten_factors`, each followed by the matching
       number of blocks from `n_funnel_blocks`), and then upsampled back.
    3. The upsampled copy is concatenated with the saved full-resolution
       copy. The result goes through a causal conv of width equal to the total
       pooling, then the `vanilla_layers[1]` relative decoder blocks, and a
       final Dense projection.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each decoder
        block.
    vanilla_layers: (pre_layers, post_layers) tuple - number of full
        token-level Transformer decoder layers before and after shortening.
    shorten_factors: by how much to shorten at each step - tuple of arbitrary
        length denoting by how much to shorten at each pooling stage.
    n_funnel_blocks: number of Transformer decoder blocks after each stage of
        pooling - tuple of the same length as `shorten_factors`.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: str: 'train' or 'eval' ('predict' is not supported).
    ff_activation: Type of activation function at the end of each decoder
        block; must be an activation-type subclass of `Layer`. It is also used
        after the causal convolution.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations (logits) over a vocab set.

  Raises:
    AssertionError: if `mode == 'predict'`, or if `n_funnel_blocks` and
        `shorten_factors` have different lengths.
  """
  assert mode != 'predict'  # For now, 'predict' mode is unsupported.
  # zip() below would silently truncate on a length mismatch.
  assert len(n_funnel_blocks) == len(shorten_factors)

  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)]

  # The same pair of relative-attention bias layer objects is passed to every
  # relative decoder block built below.
  context_bias_layer, location_bias_layer = _get_rel_att_inputs(d_model,
                                                                n_heads)

  n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

  def create_decoder_blocks(n_layers, total_pooling):  # pylint: disable=invalid-name
    """Returns `n_layers` relative decoder blocks followed by a LayerNorm."""
    decoder_blocks = [
        # pylint: disable=g-complex-comprehension
        _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation,
                              context_bias_layer, location_bias_layer,
                              total_pooling)
        for _ in range(n_layers)]
    return decoder_blocks + [tl.LayerNorm()]

  pre_decoder_blocks = create_decoder_blocks(n_pre_decoder_blocks,
                                             total_pooling=1)

  # Build the funnel. total_pooling_acc is the cumulative shortening applied
  # so far (product of the shorten factors seen).
  total_pooling_acc = 1
  funnel_blocks = []
  for shorten_factor, block_len in zip(shorten_factors, n_funnel_blocks):
    funnel_blocks.append(_FunnelRelativeDecoderBlock(
        d_model, d_ff, n_heads, dropout,
        dropout_shared_axes, mode,
        ff_activation,
        context_bias_layer=context_bias_layer,
        location_bias_layer=location_bias_layer,
        total_pooling=total_pooling_acc,
        shorten_factor=shorten_factor,
        resampler_fn=_DownsamplerLM))
    total_pooling_acc *= shorten_factor
    funnel_blocks.extend(create_decoder_blocks(block_len, total_pooling_acc))

  # A single upsampling block undoes the whole accumulated shortening.
  upsampling_layer = _FunnelRelativeDecoderBlock(
      d_model, d_ff, n_heads, dropout,
      dropout_shared_axes, mode,
      ff_activation,
      context_bias_layer=context_bias_layer,
      location_bias_layer=location_bias_layer,
      total_pooling=total_pooling_acc,
      shorten_factor=total_pooling_acc,
      resampler_fn=_UpsamplerLM)

  conv_layer = tl.Serial(
      tl.CausalConv(d_model, total_pooling_acc),
      ff_activation()
  )

  post_decoder_blocks = create_decoder_blocks(n_post_decoder_blocks,
                                              total_pooling=1)

  # Assemble and return the model.
  return tl.Serial(              # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      token_encoder,             # vecs
      pre_decoder_blocks,        # vecs
      tl.Dup(),                  # Stack: (vecs, vecs); top copy is funneled.
      # Extra shift so pooled groups cannot see future tokens.
      tl.ShiftRight(n_positions=total_pooling_acc - 1, mode=mode),
      funnel_blocks,
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
      upsampling_layer,
      tl.LayerNorm(),
      tl.Concatenate(),          # Join with the full-resolution copy.
      conv_layer,
      post_decoder_blocks,
      tl.Dense(vocab_size),      # logits
  )
def RelformerLM(vocab_size,
                d_model=512,
                d_ff=2048,
                vanilla_layers=(1, 1),
                shorten_factor=3,
                n_rel_layers=6,
                n_heads=8,
                dropout=0.1,
                dropout_shared_axes=None,
                vanilla_attn_type=tl.LSHSelfAttention,
                pos_type='fixed-base',
                max_len=3072,
                n_raw_tokens_generated=1,
                mode='train',
                ff_activation=tl.FastGelu):
  """Returns a Relformer language model (Transformer with a shortened middle).

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor of unnormalized per-token scores (logits) for each
      sequence position over possible token IDs; shape is (batch_size,
      sequence_length, `vocab_size`). The last layer is a plain
      `tl.Dense(vocab_size)`; no LogSoftmax is applied.

  This model uses only the decoder part of the overall Transformer. The
  pipeline is:
    1. Token embedding and positional encoding.
    2. The `vanilla_layers[0]` reversible Reformer decoder blocks.
    3. A shortened copy is made: shift right by `shorten_factor - 1`,
       downsample by `shorten_factor`, then `n_rel_layers` relative-attention
       decoder blocks, then upsample.
    4. The result is concatenated with the full-resolution activations.
       It then goes through a causal conv of width `shorten_factor`, the
       `vanilla_layers[1]` reversible decoder blocks, and the output Dense.
  In 'predict' mode, the RelformerCacher / RelformerPicker /
  PickLastTokenInPredict layers are inserted around the shortened part and the
  conv. Their exact caching semantics are defined elsewhere.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each decoder
        block.
    vanilla_layers: (pre_layers, post_layers) tuple - number of full
        token-level Transformer decoder layers before and after shortening.
    shorten_factor: by how much to shorten the sequence for the relative
        layers.
    n_rel_layers: number of Transformer blocks after the pooling. These blocks
        use relative attention.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    vanilla_attn_type: class: attention class such as SelfAttention to use in
        the layers before and after shortening (vanilla layers).
    pos_type: string, the type of positional embeddings to use (passed to
        PositionalEncoder).
    max_len: int: maximum symbol length for positional encoding; it is also
        the maximum inference length in 'predict' mode.
    n_raw_tokens_generated: int: number of tokens generated with every pass
        through the model in 'predict' mode. It should be no larger than
        `shorten_factor` and should evenly divide it. It cannot be larger than
        one when vanilla layers are used, because that would break the
        model's autoregressive property. These constraints are not validated
        in this function.
    mode: str: 'train' or 'eval' or 'predict'.
    ff_activation: Type of activation function at the end of each decoder
        block; must be an activation-type subclass of `Layer`. It is also used
        after the causal convolution.

  Returns:
    A Transformer language model as a layer that maps from a tensor of tokens
    to activations (logits) over a vocab set.
  """
  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
  ]

  positional_encoder = PositionalEncoder(mode, dropout, max_len, pos_type)

  n_pre_decoder_blocks, n_post_decoder_blocks = vanilla_layers

  def create_decoder_blocks(n_layers, total_pooling):  # pylint: disable=invalid-name
    """Returns `n_layers` relative decoder blocks followed by a LayerNorm.

    All blocks from one call share one freshly created pair of
    relative-attention bias layers.
    """
    context_bias_layer, location_bias_layer = _get_rel_att_inputs(
        d_model, n_heads)
    decoder_blocks = [
        # pylint: disable=g-complex-comprehension
        _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout,
                              dropout_shared_axes, mode, ff_activation,
                              context_bias_layer, location_bias_layer,
                              total_pooling, max_len)
        for _ in range(n_layers)
    ]
    return decoder_blocks + [tl.LayerNorm()]

  def create_reformer_blocks(n_layers, dense=True):  # pylint: disable=invalid-name
    """Returns a reversible stack of `n_layers` Reformer DecoderBlocks.

    The stack is Dup -> ReversibleSerial -> Concatenate (joining the two
    reversible halves) -> LayerNorm, plus a Dense(d_model) projection when
    `dense` is true. With `n_layers == 0` only a LayerNorm is returned.
    """
    if n_layers == 0:
      return [tl.LayerNorm()]
    d_per_head = d_model // n_heads
    decoder_blocks = [
        DecoderBlock(d_model, d_ff, d_per_head, d_per_head, n_heads,  # pylint: disable=g-complex-comprehension
                     vanilla_attn_type,
                     dropout, ff_activation, dropout,
                     ff_use_sru=0,
                     ff_chunk_size=0,
                     ff_sparsity=0,
                     attention_chunk_size=0,
                     mode=mode)
        for _ in range(n_layers)
    ]
    return [
        tl.Dup(),
        tl.ReversibleSerial(decoder_blocks),
        tl.Concatenate(),
        tl.LayerNorm(),
        tl.Dense(d_model) if dense else [],
    ]

  pre_decoder_blocks = create_reformer_blocks(n_pre_decoder_blocks,
                                              dense=True)

  relative_decoder_blocks = create_decoder_blocks(n_rel_layers,
                                                  shorten_factor)

  conv_layer = tl.Serial(tl.CausalConv(d_model, shorten_factor),
                         ff_activation())

  # Without the final Dense, the post stack's output depth is presumably
  # 2 * d_model (concatenated reversible halves) when n_post > 0.
  post_decoder_blocks = create_reformer_blocks(n_post_decoder_blocks,
                                               dense=False)

  cacher = RelformerCacher(total_kv_pooling=shorten_factor,
                           n_raw_tokens_generated=n_raw_tokens_generated,
                           max_inference_length=max_len,
                           shift=shorten_factor - 1,
                           mode=mode)
  picker = RelformerPicker(total_kv_pooling=shorten_factor,
                           n_raw_tokens_generated=n_raw_tokens_generated,
                           mode=mode)

  cacher_conv = RelformerCacher(
      total_kv_pooling=shorten_factor,
      n_raw_tokens_generated=n_raw_tokens_generated,
      max_inference_length=max_len,
      shift=shorten_factor - 1,
      sliding=True,
      mode=mode)

  picker_conv = PickLastTokenInPredict(mode=mode)

  # Assemble and return the model.
  return tl.Serial(              # tokens (or chunked tuple of tokens)
      tl.ShiftRight(mode=mode),  # toks
      token_encoder,             # vecs
      positional_encoder,
      pre_decoder_blocks,        # vecs
      tl.Dup(),                  # Stack: (vecs, vecs); top copy is shortened.
      cacher,
      # Extra shift so pooled groups cannot see future tokens.
      tl.ShiftRight(n_positions=shorten_factor - 1, mode=mode),
      _DownsamplerLM(shorten_factor, d_model),
      relative_decoder_blocks,
      tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
      _UpsamplerLM(shorten_factor, d_model),
      tl.LayerNorm(),
      picker,
      tl.Concatenate(),          # Join with the full-resolution copy.
      cacher_conv,
      conv_layer,
      picker_conv,
      post_decoder_blocks,
      tl.Dense(vocab_size),      # logits
  )