def test_state(self):
  model = tl.Parallel(tl.Dense(3), tl.Dense(5))
  self.assertIsInstance(model.state, tuple)
  self.assertLen(model.state, 2)
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
         ):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)

  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model)
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]
  encoder += [tl.Select([0], n_in=2)]  # Drop the mask

  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
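# Usage sketch (not part of the original source): the model above takes three
# inputs -- word-piece IDs, segment/type IDs, and a position index that the
# first Select drops. The tiny hyperparameters and shapes below are
# illustrative assumptions, and the sketch presumes PretrainedBERT and AddBias
# from the surrounding module behave like ordinary Trax layers.
import numpy as np
from trax import shapes

model = BERT(d_model=32, n_heads=2, d_ff=64, n_layers=2, mode='eval')
token_ids = np.ones((2, 16), dtype=np.int32)     # word-piece IDs; 0 marks padding
segment_ids = np.zeros((2, 16), dtype=np.int32)  # sentence A/B IDs in {0, 1}
idx = np.zeros((2, 16), dtype=np.int32)          # dropped by the first Select
model.init(shapes.signature((token_ids, segment_ids, idx)))
pooled, activations = model((token_ids, segment_ids, idx))
# pooled: (2, 32) tanh-projected [CLS] vector; activations: (2, 16, 32).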
def Reformer2(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, d_attention_key=None, d_attention_value=None, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, encoder_attention_type=tl.SelfAttention, encoder_decoder_attention_type=tl.SelfAttention, pos_type='fixed-base', pos_axial_shape=(), pos_d_axial_embs=None, pos_start_from_zero_prob=1.0, pos_max_offset_to_add=0, ff_activation=tl.Relu, ff_use_sru=0, ff_chunk_size=0, ff_dropout=None, ff_sparsity=0, loss_sparsity_type='mult', loss_sparsity=0, loss_d_lowrank=0, loss_sparsity_prob=None, attention_chunk_size=0, n_layers_forget=0, forget_dense=True, n_decoder_attention_layers=2, use_bfloat16=False, reversible_encoder=False, use_two_swaps_per_encoder_block=True, center_layernorm=True, half_before_layer=None, double_after_layer=None, mode='train'): """Reversible transformer encoder-decoder model. If input_vocab_size is not None, this model expects an input pair: source, target. Otherwise, it expects a triple: embedded_source, mask, target. At the moment, this model supports dot-product attention only. For the attention types in the Reformer paper, see ReformerLM. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding encoder_attention_type: class: attention class to use, such as SelfAttention encoder_decoder_attention_type: class: attention class to use, such as SelfAttention pos_type: string, the type of positional embeddings to use. pos_axial_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match pos_axial_shape, and values must sum to d_model. pos_start_from_zero_prob: how often to start from 0 during training, (if 1.0, we always start from position 0, if less, we randomize). pos_max_offset_to_add: maximum offset to add to positions during training when randomizing; this offset plus input length must still be less than max_len for all training examples. ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks ff_dropout: float: (optional) separate dropout rate at feed-forward nonlinearity. This is called relu_dropout in T2T. ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity loss_sparsity_type: str, type of sparsity to used in loss layer. See SparseDenseWithOptions for options. None if no sparsity should be used. loss_sparsity: int, the sparsity for loss layer (if used) loss_d_lowrank: int, the dimensions for intermediate layer (if used) loss_sparsity_prob: float, the probability for sparse version of loss to be used. If None, only sparse version is used. 
attention_chunk_size: int, if > 0 run attention chunked at this size n_layers_forget: how often to have a forgetting block between layers forget_dense: whether to use Dense or no-op (Serial) as a forget layer. n_decoder_attention_layers: how many attention layers in a decoder block use_bfloat16: whether to use bfloat16 for weights (default: False) reversible_encoder: whether to be reversible through the encoder use_two_swaps_per_encoder_block: whether to allow even number of swaps in the encoder center_layernorm: whether to use centering in LayerNorm (default) or if to skip it, which is known as RMS normalization. half_before_layer: int, half d_model and d_ff before that layer double_after_layer: int, double d_model and d_ff after that layer mode: str: 'train' or 'eval' Returns: A Reformer model as a layer that maps from a target, source pair to activations over a vocab set. """ # Set default dimensions for attention head key and value sizes. if (d_model / 2) % n_heads != 0: raise ValueError( f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})') if d_attention_key is None: d_attention_key = d_model // n_heads if d_attention_value is None: d_attention_value = d_model // n_heads # Set values of d_model, d_ff and d_qkv for the first stage. d_model1, d_ff1 = d_model, d_ff d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value if half_before_layer: d_model1, d_ff1 = d_model / 2, d_ff / 2 d_attention_key1 = d_attention_key / 2 d_attention_value1 = d_attention_value / 2 # Set values of d_model, d_ff and d_qkv for the final stage. d_model2, d_ff2 = d_model, d_ff d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value if double_after_layer: d_model2, d_ff2 = d_model * 2, d_ff * 2 d_attention_key2 = d_attention_key * 2 d_attention_value2 = d_attention_value * 2 # Vector embeddings. 
in_encoder, out_encoder, output_vocab_size = ( ct.EmbeddingAndPositionalEncodings( input_vocab_size, d_model1, mode, dropout, [-2], # dropout_shared_axes max_len, output_vocab_size=output_vocab_size, pos_type=pos_type, pos_axial_shape=pos_axial_shape, pos_d_axial_embs=pos_d_axial_embs, pos_start_from_zero_prob=pos_start_from_zero_prob, pos_max_offset_to_add=pos_max_offset_to_add, use_bfloat16=use_bfloat16)) # pylint: disable=g-complex-comprehension encoder_blocks = [ EncoderBlock(d_model1, d_ff1, n_heads, encoder_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, attention_chunk_size=attention_chunk_size, center_layernorm=center_layernorm, use_bfloat16=use_bfloat16, use_two_swaps_per_block=use_two_swaps_per_encoder_block, mode=mode) for _ in range(n_encoder_layers) ] # pylint: enable=g-complex-comprehension encoder = [ # vec_e mask_e tok_e tok_d tok_d tl.ReversibleSelect([0, 0]), # vec_e1 vec_e2 mask_e tok_e tok_d tok_d _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget, forget_dense) ] if not reversible_encoder: encoder += [ _XYAvg(), tl.Dense(d_model1, use_bfloat16=use_bfloat16), tl.LayerNorm(), ] encoder = tl.Serial(encoder) if mode == 'predict': encoder = tl.Cache(encoder) decoder_blocks = [] if isinstance(encoder_decoder_attention_type, (tuple, list)): assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: encoder_decoder_attention_type = [encoder_decoder_attention_type] for layer_idx in range(n_decoder_layers): layer_attention_type = encoder_decoder_attention_type[ layer_idx % len(encoder_decoder_attention_type)] # Grow d_model, d_ff, and d_qkv if requested. d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1 if half_before_layer and layer_idx >= half_before_layer: d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value if double_after_layer and layer_idx > double_after_layer: d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2 decoder_block = DecoderBlock( d_m, d_f, d_k, d_v, n_heads, attention_type=layer_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, attention_chunk_size=attention_chunk_size, n_attention_layers=n_decoder_attention_layers, center_layernorm=center_layernorm, use_bfloat16=use_bfloat16, mode=mode) decoder_blocks.append(decoder_block) if half_before_layer and layer_idx == half_before_layer - 1: decoder_blocks.append(tl.ReversibleConcatenatePair()) if double_after_layer and layer_idx == double_after_layer: decoder_blocks.append(tl.ReversibleConcatenatePair()) dense_loss_layer = tl.SparseDenseWithOptions( output_vocab_size, d_input=d_model2, sparsity_type=loss_sparsity_type, sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, prob_sparse=loss_sparsity_prob, use_bfloat16=use_bfloat16, mode=mode) # Layers to merge encoder and decoder, see below for details. if reversible_encoder: encdec_layers = [ tl.ReversibleSelect([0, 1, 4, 2, 3]), # vec_e vec_d mask_e tok_e tok_d t2.ConcatWithPadding2(mode=mode), # vec_ed vec_ed tok_e tok_d ] else: encdec_layers = [ tl.ReversibleSelect([0, 3, 1, 2]), # vec_e vec_d mask_e tok_e tok_d t2.ConcatWithPadding(mode=mode), # vec_ed tok_e tok_d tl.ReversibleSelect([0, 0]), # vec_ed vec_ed tok_e tok_d ] if input_vocab_size is not None: # Input in this case is tok_e, tok_d. 
mask_layers = [ tl.PaddingMask(), _RemoveAxes12(), ] inp_layers = tl.Serial([ tl.Select([0, 0, 0, 1]), # tok_e tok_e tok_e tok_d tl.Parallel(in_encoder, mask_layers) # vec_e mask_e tok_e tok_d ]) inp_layers = tl.AssertFunction('bt,bu->btf,bt,bt,bu', inp_layers) else: # Input in this case is vec_e, mask_e, tok_d. Where all downstream # operations expect tok_e, we give it instead mask_e, expecting that # downstream ops only are looking for padding/not padding. inp_layers = tl.Serial([ tl.Select([0, 1, 1, 2]), # vec_e mask_e tok_e tok_d tl.Parallel(in_encoder, [], _AsTokenIDs()) # vec_e mask_e tok_e tok_d ]) inp_layers = tl.AssertFunction('btg,bt,bu->btf,bt,bt,bu', inp_layers) # Assemble and return the model. return tl.Serial( inp_layers, # vec_e mask_e tok_e tok_d # Copy decoder tokens for use in loss. tl.Select([0, 1, 2, 3, 3]), # vec_e mask_e tok_e tok_d tok_d # Embed in and out tokens; done together as weights may be shared. tl.Parallel( [], [], [], # vec_e mask_e tok_e vec_d tok_d [tl.ShiftRight(mode=mode), out_encoder]), # Predict mode doesn't work with padding in encoder. Raising an exception # in jitted function isn't possible, so the second next best thing is # to convert every embedding to NaNs, so the user will not get subtly # wrong results, but clearly wrong results. (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []), # Encode. encoder, # vec_e mask_e tok_e vec_d tok_d # Concat encoder and decoder, given encoder mask. encdec_layers, # Run decoder blocks. _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget, forget_dense), # vec_ed1 vec_ed2 tok_e tok_d _XYAvg(), # vec_ed tok_e tok_d tl.LayerNorm(), # vec_ed tok_e tok_d # Separate out the encoder part from the concatenated vector. tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d # Map to output vocab. dense_loss_layer, # vec_d tok_d )
def Reformer2(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, d_attention_key=None, d_attention_value=None, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, encoder_attention_type=tl.SelfAttention, encoder_decoder_attention_type=tl.SelfAttention, pos_type='fixed-base', pos_axial_shape=(), pos_d_axial_embs=None, ff_activation=tl.Relu, ff_use_sru=0, ff_chunk_size=0, ff_dropout=None, ff_sparsity=0, loss_sparsity_type='mult', loss_sparsity=0, loss_d_lowrank=0, loss_sparsity_prob=None, attention_chunk_size=0, n_layers_forget=0, n_decoder_attention_layers=2, use_bfloat16=False, reversible_encoder=False, use_two_swaps_per_encoder_block=True, mode='train'): """Reversible transformer encoder-decoder model. This model expects an input pair: source, target. At the moment, this model supports dot-product attention only. For the attention types in the Reformer paper, see ReformerLM. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding encoder_attention_type: class: attention class to use, such as SelfAttention encoder_decoder_attention_type: class: attention class to use, such as SelfAttention pos_type: string, the type of positional embeddings to use. pos_axial_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. pos_d_axial_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match pos_axial_shape, and values must sum to d_model. ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks ff_dropout: float: (optional) separate dropout rate at feed-forward nonlinearity. This is called relu_dropout in T2T. ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity loss_sparsity_type: str, type of sparsity to used in loss layer. See SparseDenseWithOptions for options. None if no sparsity should be used. loss_sparsity: int, the sparsity for loss layer (if used) loss_d_lowrank: int, the dimensions for intermediate layer (if used) loss_sparsity_prob: float, the probability for sparse version of loss to be used. If None, only sparse version is used. attention_chunk_size: int, if > 0 run attention chunked at this size n_layers_forget: how often to have a forgetting block between layers n_decoder_attention_layers: how many attention layers in a decoder block use_bfloat16: whether to use bfloat16 for weights (default: False) reversible_encoder: whether to be reversible through the encoder use_two_swaps_per_encoder_block: whether to allow even number of swaps in the encoder mode: str: 'train' or 'eval' Returns: A Reformer model as a layer that maps from a target, source pair to activations over a vocab set. """ # Set default dimensions for attention head key and value sizes. 
if d_attention_key is None: if d_model % n_heads != 0: raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})') d_attention_key = d_model // n_heads if d_attention_value is None: if d_model % n_heads != 0: raise ValueError(f'n_heads ({n_heads}) must divide d_model ({d_model})') d_attention_value = d_model // n_heads # Vector embeddings. in_encoder, out_encoder, output_vocab_size = ( ct.EmbeddingAndPositionalEncodings( input_vocab_size, d_model, mode, dropout, [-2], # dropout_shared_axes max_len, output_vocab_size=output_vocab_size, pos_type=pos_type, pos_axial_shape=pos_axial_shape, pos_d_axial_embs=pos_d_axial_embs, use_bfloat16=use_bfloat16) ) # pylint: disable=g-complex-comprehension encoder_blocks = [ EncoderBlock( d_model, d_ff, n_heads, encoder_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, attention_chunk_size=attention_chunk_size, use_bfloat16=use_bfloat16, use_two_swaps_per_block=use_two_swaps_per_encoder_block, mode=mode) for _ in range(n_encoder_layers)] # pylint: enable=g-complex-comprehension encoder = [ # vec_e mask_e tok_e tok_d tok_d tl.ReversibleSelect([0, 0]), # vec_e1 vec_e2 mask_e tok_e tok_d tok_d _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget) ] if not reversible_encoder: encoder += [ tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), tl.Dense(d_model, use_bfloat16=use_bfloat16), tl.LayerNorm(), ] encoder = tl.Serial(encoder) if mode == 'predict': encoder = tl.Cache(encoder) decoder_blocks = [] if isinstance(encoder_decoder_attention_type, (tuple, list)): assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: encoder_decoder_attention_type = [encoder_decoder_attention_type] for layer_idx in range(n_decoder_layers): layer_attention_type = encoder_decoder_attention_type[ layer_idx % len(encoder_decoder_attention_type)] decoder_block = DecoderBlock( d_model, d_ff, d_attention_key, d_attention_value, n_heads, attention_type=layer_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, attention_chunk_size=attention_chunk_size, n_attention_layers=n_decoder_attention_layers, use_bfloat16=use_bfloat16, mode=mode) decoder_blocks.append(decoder_block) dense_loss_layer = tl.SparseDenseWithOptions( output_vocab_size, d_input=d_model, sparsity_type=loss_sparsity_type, sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, prob_sparse=loss_sparsity_prob, use_bfloat16=use_bfloat16, mode=mode) # Layers to merge encoder and decoder, see below for details. if reversible_encoder: encdec_layers = [ tl.ReversibleSelect([0, 1, 4, 2, 3]), # vec_e vec_d mask_e tok_e tok_d t2.ConcatWithPadding2(mode=mode), # vec_ed vec_ed tok_e tok_d ] else: encdec_layers = [ tl.ReversibleSelect([0, 3, 1, 2]), # vec_e vec_d mask_e tok_e tok_d t2.ConcatWithPadding(mode=mode), # vec_ed tok_e tok_d tl.ReversibleSelect([0, 0]), # vec_ed vec_ed tok_e tok_d ] # Assemble and return the model. return tl.Serial( # Input: encoder_side_tokens, decoder_side_tokens # Copy decoder tokens for use in loss. tl.Select([0, 0, 0, 1, 1]), # tok_e tok_e tok_e tok_d tok_d # Embed in and out tokens; done together as weights may be shared. 
tl.Parallel(in_encoder, [], [], # vec_e tok_e tok_e vec_d tok_d [tl.ShiftRight(mode=mode), out_encoder]), tl.Parallel([], [tl.PaddingMask(), tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]), # # vec_e mask_e tok_e vec_d tok_d # Encode. encoder, # vec_e mask_e tok_e vec_d tok_d # Concat encoder and decoder, given encoder mask. encdec_layers, # Run decoder blocks. _ReversibleSerialForget(decoder_blocks, d_model, n_layers_forget), # vec_ed1 vec_ed2 tok_e tok_d tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), # vec_ed tok_e tok_d tl.LayerNorm(), # vec_ed tok_e tok_d # Separate out the encoder part from the concatenated vector. tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d # Map to output vocab. dense_loss_layer, # vec_d tok_d )
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
  """Returns a Full Funnel Transformer, that can be used for example for BERT.

  This model outputs token-level categorical distributions over all vocab:

  - input: rank 2 tensor representing a batch of text strings via token IDs
    plus padding markers; shape is (batch_size, sequence_length). The tensor
    elements are integers in `range(vocab_size)`, and `0` values mark padding
    positions.

  - output: rank 3 tensor representing a batch of log-probability distributions
    over `vocab_size` categories for each token; shape is
    (batch_size, sequence_length, vocab_size).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Tuple, where each element denotes the number of
        transformer encoder blocks preceding a funnel transformer block.
        There is no funnel block after the last sequence of encoder blocks,
        therefore the total number of blocks in the model is equal to
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper) and only final
        embedding of this token is used for categorization - the rest are
        discarded. If `False`, each token from the beginning is pooled and
        all embeddings are averaged and mapped to output categories like in
        original `TransformerEncoder` model.
  """
  assert encoder_segment_lengths

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  n_encoder_segments = len(encoder_segment_lengths)

  encoder_blocks_before_first_pooling = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                    mode, ff_activation)
      for _ in range(encoder_segment_lengths[0])]
  encoder_blocks_from_first_pooling = []

  for i in range(1, n_encoder_segments):
    # Building i'th segment

    # Add funnel block between segments
    encoder_blocks_from_first_pooling.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout,
                     dropout_shared_axes, mode,
                     ff_activation, pool_layer, pool_size=pool_size,
                     strides=pool_size, separate_cls=separate_cls))

    for _ in range(encoder_segment_lengths[i]):
      # Create segment_size encoder blocks
      encoder_blocks_from_first_pooling.append(
          _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation))

  decoder_blocks = [_EncoderBlock(d_model, d_ff, n_heads, dropout,
                                  dropout_shared_axes, mode, ff_activation)
                    for _ in range(n_decoder_blocks)]

  total_pool_size = pool_size[0]**(len(encoder_segment_lengths) - 1)

  # Assemble and return the model.
  return tl.Serial(                                   # toks
      tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks_before_first_pooling,            # vecs masks
      tl.Select([0, 1, 0, 1]),                        # vecs masks residual = vecs old_masks
      encoder_blocks_from_first_pooling,              # vecs masks residual masks
      tl.Select([0, 2, 3]),                           # vecs residual masks
      tl.Parallel(
          # residual from first segment is taken before
          # normalization, so apply it now
          None, tl.LayerNorm(), None),                # vecs norm(residual) masks
      _Upsampler(total_pool_size, separate_cls),      # vecs masks
      decoder_blocks,
      tl.Select([0], n_in=2),                         # vecs
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
def SerializedPolicy(seq_model, n_controls, n_actions, observation_serializer,
                     action_serializer):
  """Wraps a policy in serialization machinery for training.

  The resulting model takes as input observation and action sequences, and
  serializes them into one sequence similar to SerializedModel, before passing
  to the given sequence model. Adds output heads for action logits and value
  predictions.

  Args:
    seq_model: Trax sequence model taking as input a sequence of symbols and
      outputting a sequence of continuous vectors.
    n_controls: Number of controls.
    n_actions: Number of action categories in each control.
    observation_serializer: Serializer to use for observations.
    action_serializer: Serializer to use for actions.

  Returns:
    A model of signature (obs, act) -> (act_logits, values), same as in
    RawPolicy.
  """
  if action_serializer.representation_length != n_controls:
    raise ValueError(
        'Action symbols should correspond 1-1 to controls, but got {} '
        'controls and {} symbols.'.format(
            n_controls, action_serializer.representation_length))

  def FirstSymbol():
    return tl.Fn('FirstSymbol', lambda x: x[:, :, 0])

  def PadRight(n_to_pad):
    def pad_right(x):
      pad_widths = [(0, 0), (0, n_to_pad)] + [(0, 0)] * (x.ndim - 2)
      return jnp.pad(
          x, pad_widths, mode='constant', constant_values=x.dtype.type(0))
    return tl.Fn(f'PadRight({n_to_pad})', pad_right)

  action_head = [
      tl.Dense(n_actions),
      tl.LogSoftmax(),
  ]
  value_head = [
      # Take just the vectors corresponding to the first action symbol.
      FirstSymbol(),
      # Predict values.
      tl.Dense(1),
      # Get rid of the singleton dimension.
      tl.Flatten(),
  ]
  return tl.Serial(                               # (obs, act)
      tl.Parallel(Serialize(observation_serializer),
                  Serialize(action_serializer)),  # (obs_repr, act_repr)
      Interleave(),                               # (obs_act_repr,)
      # Add one dummy action to the right - we'll use the output at its first
      # symbol to predict the value for the last observation.
      PadRight(action_serializer.representation_length),
      # Shift one symbol to the right, so we predict the n-th action symbol
      # based on action symbols 1..n-1 instead of 1..n.
      tl.ShiftRight(),
      seq_model,                                  # (obs_act_hidden,)
      Deinterleave(observation_serializer.representation_length,
                   action_serializer.representation_length),
                                                  # (obs_hidden, act_hidden)
      tl.Select([1, 1]),                          # (act_hidden, act_hidden)
      tl.Parallel(action_head, value_head),       # (act_logits, values)
  )
def ConfigurableTerraformer(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, d_attention_key=None, d_attention_value=None, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, encoder_attention_type=tl.SelfAttention, encoder_decoder_attention_type=tl.SelfAttention, pos_type='fixed-base', pos_axial_shape=(), pos_d_axial_embs=None, pos_start_from_zero_prob=1.0, pos_max_offset_to_add=0, ff_activation=tl.Relu, ff_use_sru=0, ff_chunk_size=0, ff_dropout=None, ff_sparsity=0, loss_sparsity_type='mult', loss_sparsity=0, loss_d_lowrank=0, loss_sparsity_prob=None, attention_chunk_size=0, n_layers_forget=0, forget_dense=True, n_decoder_attention_layers=2, use_bfloat16=False, reversible_encoder=False, use_two_swaps_per_encoder_block=True, center_layernorm=True, half_before_layer=None, double_after_layer=None, mode='train'): """Returns a highly configurable Terraformer encoder-decoder model. This model maps paired text sequences (source and target) to float-valued losses. If ``input_vocab_size`` is not ``None``, the layer takes two input sequences: - inputs (2): - source: 2-D int array representing a batch of text strings via token IDs plus padding markers; shape is `(batch_size, sequence_length)`, where sequence_length <= ``max_len``. Array elements are in ``range(input_vocab_size)``, and 0 values mark padding positions. - target: 2-D int array representing a batch of text strings via token IDs plus padding markers; shape is `(batch_size, sequence_length)`, where sequence_length <= ``max_len``. Array elements are in ``range(output_vocab_size)``, and 0 values mark padding positions. - output: 1-D float array of losses; shape is `(batch_size)`. If ``input_vocab_size`` is ``None``, the layer takes three input sequences: - inputs (3): - source: 3-D float array representing a batch of already-embedded text strings; shape is `(batch_size, sequence_length, d_model)`, where sequence_length <= ``max_len``. - mask: 2-D int array representing active versus masked positions; 0 values mark masked (padding) positions. - target: 2-D int array representing a batch of text strings via token IDs plus padding markers; shape is `(batch_size, sequence_length)`, where sequence_length <= ``max_len``. Array elements are in ``range(output_vocab_size)``, and 0 values mark padding positions. - output: 1-D float array of losses; shape is `(batch_size)`. Args: input_vocab_size: Input vocabulary size -- each element of the input tensor should be an integer in ``range(vocab_size)``. These integers typically represent token IDs from a vocabulary-based tokenizer. output_vocab_size: If specified, gives the vocabulary size for the targets; if ``None``, then input and target integers (token IDs) are assumed to come from the same vocabulary. d_model: Last/innermost dimension of activation arrays at most points in the model, including the initial embedding output. d_ff: Last/innermost dimension of special (typically wider) :py:class:`Dense` layer in the feedforward part of each encoder block. d_attention_key: Depth of key vectors in each attention head. d_attention_value: Depth of value vectors in each attention head. n_encoder_layers: Number of encoder blocks. n_decoder_layers: Number of decoder blocks. n_heads: Number of attention heads. dropout: Stochastic rate (probability) for dropping an activation value when applying dropout within encoder/decoder blocks. The same rate is also used for attention dropout in encoder/decoder blocks. max_len: Maximum symbol length for positional encoding. 
encoder_attention_type: Type of attention to use in the encoder; must be an attention-type subclass of :py:class:`trax.layers.Layer`. encoder_decoder_attention_type: Type of attention to use in the decoder; must be an attention-type subclass of :py:class:`trax.layers.Layer`. pos_type: String indicating the type of positional embeddings to use. pos_axial_shape: Shape (tuple of ints) to use for the axial position encoding. If unset, axial position encoding is disabled. pos_d_axial_embs: Tuple of ints specifying the depth of position embedding for each axis. Tuple length must match ``pos_axial_shape``, and values must sum to ``d_model``. pos_start_from_zero_prob: Stochastic rate (probability) for starting positional encoding at position 0 during training. If 1.0, always start from position 0; if < 1.0, the non-zero starts will be uniformly distributed up to ``pos_max_offset_to_add``. pos_max_offset_to_add: Maximum offset to add to positions during training when randomizing. This offset plus input length must be less than ``max_len`` for all training examples. ff_activation: Type of activation function at the end of each block; must be an activation-type subclass of :py:class:`trax.layers.Layer`. ff_use_sru: If > 0, use this number of SRU layers in place of feedforward layers. ff_chunk_size: If > 0, chunk each feedforward layer into chunks of this size. ff_dropout: Stochastic rate (probability) for dropping an activation value at feedforward nonlinearities. ff_sparsity: If > 0, use sparse feedforward blocks with this level of sparsity. loss_sparsity_type: String indicating the type of sparsity to used in loss layer; see :py:class:`SparseDenseWithOptions` for options. If ``None``, use no sparsity. loss_sparsity: If > 0, use this level of sparsity in the loss layer. loss_d_lowrank: If > 0, use a (low-rank) intermediate layer, with this dimension, in the loss. loss_sparsity_prob: Stochastic rate (probability) for using the sparse version of the loss. If ``None``, use the sparse version exclusively. attention_chunk_size: If > 0, compute attention using chunks of this size. n_layers_forget: How often to have a forgetting block between layers. forget_dense: If True, use :py:class:`Dense` instances as forget layers; else use no-ops. n_decoder_attention_layers: Number of attention layers in a decoder block. use_bfloat16: If True, use bfloat16 for weights; else use float32. reversible_encoder: If True, make the encoder be reversible. use_two_swaps_per_encoder_block: If True, ensure that there is a an even number of swaps across the encoder. center_layernorm: If True, use centering in :py:class:`LayerNorm` (the default); else omit centering (which is known as RMS normalization). half_before_layer: If not None, specifies an n'th layer such that all layers before the n'th use half the normal values for ``d_model`` and ``d_ff``. double_after_layer: If not None, specifies an n'th layer such that all layers after the n'th use double the normal values for ``d_model`` and ``d_ff``. mode: If ``'train'``, include dropout in each encoder/decoder block; else dropout layers have no effect. Returns: A Terraformer encoder-decoder as a layer that maps from target and source text sequences to a scalar loss. """ # Set default dimensions for attention head key and value sizes. 
if (d_model / 2) % n_heads != 0: raise ValueError( f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})') if d_attention_key is None: d_attention_key = d_model // n_heads if d_attention_value is None: d_attention_value = d_model // n_heads # Set values of d_model, d_ff and d_qkv for the first stage. d_model1, d_ff1 = d_model, d_ff d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value if half_before_layer: d_model1, d_ff1 = d_model / 2, d_ff / 2 d_attention_key1 = d_attention_key / 2 d_attention_value1 = d_attention_value / 2 # Set values of d_model, d_ff and d_qkv for the final stage. d_model2, d_ff2 = d_model, d_ff d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value if double_after_layer: d_model2, d_ff2 = d_model * 2, d_ff * 2 d_attention_key2 = d_attention_key * 2 d_attention_value2 = d_attention_value * 2 # Vector embeddings. in_encoder, out_encoder, output_vocab_size = ( ct.EmbeddingAndPositionalEncodings( input_vocab_size, d_model1, mode, dropout, [-2], # dropout_shared_axes max_len, output_vocab_size=output_vocab_size, pos_type=pos_type, pos_axial_shape=pos_axial_shape, pos_d_axial_embs=pos_d_axial_embs, pos_start_from_zero_prob=pos_start_from_zero_prob, pos_max_offset_to_add=pos_max_offset_to_add, use_bfloat16=use_bfloat16)) def _EncoderBlock(): return reformer.EncoderBlock( d_model1, d_ff1, n_heads, encoder_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, attention_chunk_size=attention_chunk_size, center_layernorm=center_layernorm, use_bfloat16=use_bfloat16, use_two_swaps_per_block=use_two_swaps_per_encoder_block, mode=mode) def _Encoder(): # vec_e mask_e tok_e tok_d tok_d layers = [ tl.ReversibleSelect([0, 0]), _ReversibleSerialForget( [_EncoderBlock() for _ in range(n_encoder_layers)], d_model1, n_layers_forget, forget_dense) ] if not reversible_encoder: layers += [ _XYAvg(), tl.Dense(d_model1, use_bfloat16=use_bfloat16), tl.LayerNorm(), ] if mode == 'predict': return tl.Cache(tl.Serial(layers)) else: return tl.Serial(layers) decoder_blocks = [] if isinstance(encoder_decoder_attention_type, (tuple, list)): assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: encoder_decoder_attention_type = [encoder_decoder_attention_type] for layer_idx in range(n_decoder_layers): layer_attention_type = encoder_decoder_attention_type[ layer_idx % len(encoder_decoder_attention_type)] # Grow d_model, d_ff, and d_qkv if requested. 
d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1 if half_before_layer and layer_idx >= half_before_layer: d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value if double_after_layer and layer_idx > double_after_layer: d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2 decoder_block = reformer.DecoderBlock( d_m, d_f, d_k, d_v, n_heads, attention_type=layer_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, attention_chunk_size=attention_chunk_size, n_attention_layers=n_decoder_attention_layers, center_layernorm=center_layernorm, use_bfloat16=use_bfloat16, mode=mode) decoder_blocks.append(decoder_block) if half_before_layer and layer_idx == half_before_layer - 1: decoder_blocks.append(tl.ReversibleConcatenatePair()) if double_after_layer and layer_idx == double_after_layer: decoder_blocks.append(tl.ReversibleConcatenatePair()) def _Loss(): return tl.SparseDenseWithOptions(output_vocab_size, d_input=d_model2, sparsity_type=loss_sparsity_type, sparsity=loss_sparsity, d_lowrank=loss_d_lowrank, prob_sparse=loss_sparsity_prob, use_bfloat16=use_bfloat16, mode=mode) def _enc_dec_concat(): """Layers to merge encoder and decoder.""" if reversible_encoder: return [ tl.ReversibleSelect([0, 1, 4, 2, 3]), # v_e v_d mask_e tok_e tok_d t2.ConcatWithPadding2(mode=mode), # v_ed v_ed tok_e tok_d ] else: return [ tl.ReversibleSelect([0, 3, 1, 2]), # v_e v_d mask_e tok_e tok_d t2.ConcatWithPadding(mode=mode), # v_ed tok_e tok_d tl.ReversibleSelect([0, 0]), # v_ed v_ed tok_e tok_d ] def _inp_layers(): if input_vocab_size is not None: return tl.AssertFunction( 'bl,br->bld,bl,bl,br', # b: batch, l/r: enc/dec length, d: vec depth tl.Serial( # tok_e tok_d tl.Select([0, 0, 0, 1]), tl.Parallel( in_encoder, [tl.PaddingMask(), _RemoveAxes12() ]))) # vec_e mask_e tok_e tok_d else: # Input in this case is vec_e, mask_e, tok_d. Where all downstream # operations expect tok_e, we give it instead mask_e, expecting that # downstream ops only are looking for padding/not padding. return tl.AssertFunction( 'blf,bl,br->bld,bl,bl,br', # f: in-feature depth, d: out-vector depth tl.Serial( # vec_e mask_e tok_d tl.Select([0, 1, 1, 2]), tl.Parallel(in_encoder, [], _AsTokenIDs()))) # vec_e mask_e tok_e tok_d # Assemble and return the model. return tl.Serial( _inp_layers(), # vec_e mask_e tok_e tok_d tl.Select([0, 1, 2, 3, 3]), # Copy decoder tokens for use in loss. # Embed in and out tokens; done together as weights may be shared. tl.Parallel([], [], [], [tl.ShiftRight(mode=mode), out_encoder ]), # vec_e mask_e tok_e vec_d tok_d # Predict mode doesn't work with padding in encoder. Raising an exception # in jitted function isn't possible, so the next best thing is to convert # every embedding to NaNs, so the user will get unmistakably wrong # results. (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []), # Encode; then concat encoder and decoder, given encoder mask. _Encoder(), # vec_e mask_e tok_e vec_d tok_d _enc_dec_concat(), # Run decoder blocks. _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget, forget_dense), # vec_ed1 vec_ed2 tok_e tok_d _XYAvg(), # vec_ed tok_e tok_d tl.LayerNorm(), # vec_ed tok_e tok_d # Separate out the encoder part from the concatenated vector, # then compute loss. tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d _Loss(), # vec_d tok_d )
def policy_and_value_net(bottom_layers_fn, observation_space, action_space,
                         vocab_size, two_towers):
  """A policy and value net function.

  Runs bottom_layers_fn either as a single network or as two separate towers.
  Attaches action and value heads and wraps the network in a policy wrapper.

  Args:
    bottom_layers_fn: Trax model to use as a policy network.
    observation_space (gym.Space): Observation space.
    action_space (gym.Space): Action space.
    vocab_size (int or None): Vocabulary size to use with a SerializedPolicy
      wrapper. If None, RawPolicy will be used.
    two_towers (bool): Whether to run bottom_layers_fn as two separate towers
      for action and value prediction.

  Returns:
    Pair (network, substitute_fn), where network is the final network and
    substitute_fn is a function (wrapped_tree, inner_tree) -> wrapped_tree for
    substituting weights or state of the constructed model based on the
    weights or state of a model returned from bottom_layers_fn. substitute_fn
    is used for initializing the policy from parameters of a world model.
  """
  kwargs = {}
  if vocab_size is not None:
    kwargs['vocab_size'] = vocab_size

  def wrapped_policy_fn():
    return serialization_utils.wrap_policy(
        bottom_layers_fn(**kwargs),
        observation_space,
        action_space,
        vocab_size,
    )

  # Now, with the current logits, one head computes action probabilities and
  # the other computes the value function.
  # NOTE: The LogSoftmax instead of the Softmax because of numerical stability.

  if two_towers:
    # Two towers: run two two-head networks in parallel and drop one head from
    # each.
    net = tl.Serial([            # (obs, act)
        tl.Select([0, 1, 0, 1]),  # (obs, act, obs, act)
        tl.Parallel(
            wrapped_policy_fn(),
            wrapped_policy_fn(),
        ),                       # (act_logits_1, vals_1, act_logits_2, vals_2)
        tl.Select([0, 3]),       # (act_logits_1, vals_2)
    ])

    def substitute_fn(wrapped_policy, inner_policy):
      return (
          wrapped_policy[:1] +
          [tuple(  # Substitute in both towers.
              serialization_utils.substitute_inner_policy(  # pylint: disable=g-complex-comprehension
                  tower, inner_policy, vocab_size)
              for tower in wrapped_policy[1])] +
          [wrapped_policy[2:]])
  else:
    # One tower: run one two-headed network.
    net = wrapped_policy_fn()
    substitute_fn = functools.partial(
        serialization_utils.substitute_inner_policy, vocab_size=vocab_size,
    )
  return (net, substitute_fn)
def test_dup_dup(self):
  layer = tl.Parallel(tl.Dup(), tl.Dup())
  xs = [np.array([1, 2, 3]),
        np.array([10, 20])]
  ys = layer(xs)
  self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [10, 20], [10, 20]])
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
  """Returns a Transformer model.

  This model expects an input pair: target, source.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int: depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
  in_embed = [                                  # tokens
      tl.Embedding(d_model, input_vocab_size),  # vecs
      tl.Dropout(rate=dropout, mode=mode),      # vecs
      tl.PositionalEncoding(max_len=max_len),   # vecs
  ]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
    out_embed = in_embed
  else:
    out_embed = [                                  # tokens
        tl.Embedding(d_model, output_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),       # vecs
        tl.PositionalEncoding(max_len=max_len),    # vecs
    ]

  encoder_stack = (  # masks vectors --> masks vectors
      [EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
       for i in range(n_encoder_layers)])

  encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
      [EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
       for i in range(n_decoder_layers)])

  # Input: encoder_side_tokens, decoder_side_tokens
  return tl.Serial(               # tokens_e tokens_d
      tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
      tl.Swap(),                  # toks_d toks_e ....

      # Encode.
      tl.Parallel(                                       # toks_d        toks_e
          [], [tl.Dup(),                                 # ______ toks_e toks_e
               tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
               encoder_stack,                            # ______ vecs_e masks
               tl.LayerNorm(),                           # ______ vecs_e .....
               tl.Swap()]),                              # ______ masks  vecs_e

      # Decode.                                  #        toks_d masks vecs_e
      tl.ShiftRight(),                           #        toks_d ..... ......
      out_embed,                                 #        vecs_d ..... ......
      tl.Dup(),                                  # vecs_d vecs_d ..... ......
      tl.Parallel([], tl.EncoderDecoderMask()),  # ______ masks ......
      encoder_decoder_stack,                     # vecs_d masks vecs_e
      tl.Parallel([], tl.Drop(), tl.Drop()),     # vecs_d
      tl.LayerNorm(),                            # vecs_d
      tl.Dense(output_vocab_size),               # vecs_d
      tl.LogSoftmax(),                           # vecs_d
  )
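# Hypothetical usage sketch for the Transformer above (not from the original
# source). It assumes the module's EncoderBlock and EncoderDecoder helpers are
# available alongside it; the tiny dimensions are illustrative only.
import numpy as np
from trax import shapes

model = Transformer(input_vocab_size=32, d_model=16, d_ff=32,
                    n_encoder_layers=1, n_decoder_layers=1, n_heads=2,
                    max_len=64, mode='eval')
src = np.ones((2, 8), dtype=np.int32)  # encoder-side token IDs
tgt = np.ones((2, 8), dtype=np.int32)  # decoder-side token IDs
model.init(shapes.signature((src, tgt)))
logprobs, _ = model((src, tgt))        # log-probs plus the copied target tokens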
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
  """Returns an LSTM sequence-to-sequence model with attention.

  The input to the model is a pair (input tokens, target tokens), e.g.,
  an English sentence (tokenized) and its translation into German (tokenized).

  The model works as follows:
  * Input encoder runs on the input tokens and creates activations that
    are used as both keys and values in attention.
  * Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.
  * Attention runs on the queries, keys and values masking out input padding.
  * Decoder runs on the result, followed by a cross-entropy loss.

  Args:
    input_vocab_size: int: vocab size of the input
    target_vocab_size: int: vocab size of the target
    d_model: int: depth of embedding (n_units in the LSTM cell)
    n_encoder_layers: int: number of LSTM layers in the encoder
    n_decoder_layers: int: number of LSTM layers in the decoder after attention
    n_attention_heads: int: number of attention heads
    attention_dropout: float, dropout for the attention layer
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    An LSTM sequence-to-sequence model with attention.
  """
  input_encoder = tl.Serial(
      tl.Embedding(d_model, input_vocab_size),
      [tl.LSTM(d_model) for _ in range(n_encoder_layers)],
  )

  pre_attention_decoder = tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, target_vocab_size),
      tl.LSTM(d_model),
  )

  def PrepareAttentionInputs():
    """Layer that prepares queries, keys, values and mask for attention."""
    def F(encoder_activations, decoder_activations, input_tokens):
      keys = values = encoder_activations
      queries = decoder_activations
      # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
      mask = (input_tokens != 0)
      # We need to add axes to the mask for attention heads and decoder length.
      mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
      # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
      mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1))
      return queries, keys, values, mask
    return tl.Fn('PrepareAttentionInputs', F, n_out=4)

  return tl.Serial(              # in-toks, target-toks
      tl.Select([0, 1, 0, 1]),   # in-toks, target-toks, in-toks, target-toks
      tl.Parallel(input_encoder, pre_attention_decoder),
      PrepareAttentionInputs(),  # q, k, v, mask, target-toks
      tl.Residual(
          tl.AttentionQKV(d_model, n_heads=n_attention_heads,
                          dropout=attention_dropout, mode=mode)
      ),                         # decoder-vecs, mask, target-toks
      tl.Select([0, 2]),         # decoder-vecs, target-toks
      [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
      tl.Dense(target_vocab_size),
      tl.LogSoftmax())
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
  """Returns an LSTM sequence-to-sequence model with attention.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  The model works as follows:

  * Input encoder runs on the input tokens and creates activations that
    are used as both keys and values in attention.
  * Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.
  * Attention runs on the queries, keys and values masking out input padding.
  * Decoder runs on the result, followed by a cross-entropy loss.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    target_vocab_size: Target vocabulary size.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    n_encoder_layers: Number of LSTM layers in the encoder.
    n_decoder_layers: Number of LSTM layers in the decoder after attention.
    n_attention_heads: Number of attention heads.
    attention_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout within an attention block.
    mode: If `'predict'`, use fast inference. If `'train'`, each attention block
        will include dropout; else, it will pass all values through unaltered.

  Returns:
    An LSTM sequence-to-sequence model as a layer that maps from a
    source-target tokenized text pair to activations over a vocab set.
  """
  input_encoder = tl.Serial(
      tl.Embedding(input_vocab_size, d_model),
      [tl.LSTM(d_model) for _ in range(n_encoder_layers)],
  )

  pre_attention_decoder = tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(target_vocab_size, d_model),
      tl.LSTM(d_model, mode=mode),
  )

  def PrepareAttentionInputs():
    """Layer that prepares queries, keys, values and mask for attention."""
    def F(encoder_activations, decoder_activations, input_tokens):
      keys = values = encoder_activations
      queries = decoder_activations
      # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
      mask = (input_tokens != 0)
      # We need to add axes to the mask for attention heads and decoder length.
      mask = jnp.reshape(mask, (mask.shape[0], 1, 1, mask.shape[1]))
      # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
      mask = mask + jnp.zeros((1, 1, decoder_activations.shape[1], 1))
      mask = mask.astype(jnp.float32)
      return queries, keys, values, mask
    return tl.Fn('PrepareAttentionInputs', F, n_out=4)

  return tl.Serial(              # in-toks, target-toks
      tl.Select([0, 1, 0, 1]),   # in-toks, target-toks, in-toks, target-toks
      tl.Parallel(input_encoder, pre_attention_decoder),
      PrepareAttentionInputs(),  # q, k, v, mask, target-toks
      tl.Residual(
          tl.AttentionQKV(d_model, n_heads=n_attention_heads,
                          dropout=attention_dropout, mode=mode,
                          cache_KV_in_predict=True)
      ),                         # decoder-vecs, mask, target-toks
      tl.Select([0, 2]),         # decoder-vecs, target-toks
      [tl.LSTM(d_model, mode=mode) for _ in range(n_decoder_layers)],
      tl.Dense(target_vocab_size),
      tl.LogSoftmax())
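# Hypothetical usage sketch for the model above (not from the original source);
# the small vocabulary sizes and shapes are illustrative assumptions.
import numpy as np
from trax import shapes

model = LSTMSeq2SeqAttn(input_vocab_size=32, target_vocab_size=32, d_model=16,
                        n_encoder_layers=1, n_decoder_layers=1, mode='train')
src = np.ones((2, 7), dtype=np.int32)  # (batch, src_len) token IDs; 0 = padding
tgt = np.ones((2, 5), dtype=np.int32)  # (batch, tgt_len) token IDs
model.init(shapes.signature((src, tgt)))
logprobs, _ = model((src, tgt))        # (2, 5, 32) log-probs; targets kept for the loss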
def Dup2():
  """Copy first 2 elements of the stack: (a, b, ...) -> (a, b, a, b, ...)."""
  return [                              # Stack is (a, b, ...)
      tl.Parallel(tl.Dup(), tl.Dup()),  # Stack is (a, a, b, b, ...)
      tl.Parallel([], tl.Swap())        # Stack is (a, b, a, b, ...)
  ]
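# Usage sketch (an illustration, not from the original source): Dup2() returns
# a list of layers, so wrap it in tl.Serial to run it; no trainable weights are
# involved, so the layer can be called directly.
import numpy as np
from trax import layers as tl

layer = tl.Serial(Dup2())
a, b = np.array([1, 2, 3]), np.array([10, 20])
outputs = layer((a, b))  # -> (a, b, a, b)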
def LearnedQP(keys=None, values=None, binary=False):
  """Get (query, pos), compute a learned weight of the query, return it with pos."""
  return tl.Parallel(
      tl.Dense(1),
      QueryPositionKV(keys=keys, values=values, binary=binary),
  )
def some_layer():
  return tl.Parallel(DivideBy(2.0), DivideBy(5.0))
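# DivideBy is not defined in this excerpt; a minimal stand-in consistent with
# the surrounding tests could look like the sketch below (an assumption, not
# the original helper).
import numpy as np
from trax import layers as tl

def DivideBy(val):
  """Returns a weightless layer that divides its input by `val`."""
  return tl.Fn('DivideBy', lambda x: x / val)

layer = some_layer()  # tl.Parallel(DivideBy(2.0), DivideBy(5.0))
ys = layer((np.array([2.0, 4.0]), np.array([5.0, 10.0])))  # ([1., 2.], [1., 2.])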
def test_div_div(self):
  layer = tl.Parallel(DivideBy(0.5), DivideBy(3.0))
  xs = [np.array([1, 2, 3]),
        np.array([30, 60])]
  ys = layer(xs)
  self.assertEqual(as_list(ys), [[2, 4, 6], [10, 20]])
def Reformer2(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, d_attention_key=None, d_attention_value=None, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, encoder_attention_type=tl.SelfAttention, encoder_decoder_attention_type=tl.SelfAttention, axial_pos_shape='fixed-base', d_axial_pos_embs=None, ff_activation=tl.Relu, ff_use_sru=0, ff_chunk_size=0, ff_dropout=None, ff_sparsity=0, n_layers_forget=0, mode='train'): """Reversible transformer encoder-decoder model. This model expects an input pair: source, target. At the moment, this model supports dot-product attention only. For the attention types in the Reformer paper, see ReformerLM. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding encoder_attention_type: class: attention class to use, such as SelfAttention encoder_decoder_attention_type: class: attention class to use, such as SelfAttention axial_pos_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. d_axial_pos_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match axial_pos_shape, and values must sum to d_model. ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks ff_dropout: float: (optional) separate dropout rate at feed-forward nonlinearity. This is called relu_dropout in T2T. ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity n_layers_forget: how often to have a forgetting block between layers mode: str: 'train' or 'eval' Returns: A Reformer model as a layer that maps from a target, source pair to activations over a vocab set. """ # The current API for custom gradients assumes that a layer must be # differentiable wrt all of its inputs, but the Transformer puts bool-dtype # masks on the stack. This causes jax to error, even though the so-called # "gradient" wrt the masks is never actually computed. # TODO(kitaev): remove this hack. if fastmath.is_backend(fastmath.Backend.JAX): jax.api._check_inexact_input_vjp = lambda x: None # pylint: disable=protected-access # Set default dimensions for attention head key and value sizes. if d_attention_key is None: if d_model % n_heads != 0: raise ValueError( f'n_heads ({n_heads}) must divide d_model ({d_model})') d_attention_key = d_model // n_heads if d_attention_value is None: if d_model % n_heads != 0: raise ValueError( f'n_heads ({n_heads}) must divide d_model ({d_model})') d_attention_value = d_model // n_heads # Vector embeddings. 
def Embedder(vocab_size): # tokens --> vectors return [ tl.Embedding(vocab_size, d_model), tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), ] in_embedder = Embedder(input_vocab_size) out_embedder = (in_embedder if output_vocab_size is None else Embedder(output_vocab_size)) def PositionalEnc(mode): return PositionalEncoding(mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs) # Mode 'predict' means that the decoder should be run one token at a time. # The encoder only ever runs over full sequences, which is why it's switched # to 'eval' mode instead. encoder_mode = 'eval' if mode == 'predict' else mode in_encoder = in_embedder + [PositionalEnc(encoder_mode)] out_encoder = out_embedder + [PositionalEnc(mode)] if output_vocab_size is None: output_vocab_size = input_vocab_size # pylint: disable=g-complex-comprehension encoder_blocks = [ EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, mode=mode) for _ in range(n_encoder_layers) ] # pylint: enable=g-complex-comprehension encoder = tl.Serial([ # vec_e mask_e tok_e tok_d tok_d tl.Dup(), # vec_e1 vec_e2 mask_e tok_e tok_d tok_d _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget), tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), tl.Dense(d_model), tl.LayerNorm(), ]) if mode == 'predict': encoder = tl.Cache(encoder) decoder_blocks = [] if isinstance(encoder_decoder_attention_type, (tuple, list)): assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: encoder_decoder_attention_type = [encoder_decoder_attention_type] for layer_idx in range(n_decoder_layers): layer_attention_type = encoder_decoder_attention_type[ layer_idx % len(encoder_decoder_attention_type)] decoder_block = DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads, attention_type=layer_attention_type, dropout=dropout, ff_activation=ff_activation, ff_dropout=ff_dropout, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, ff_sparsity=ff_sparsity, mode=mode) decoder_blocks.append(decoder_block) # Assemble and return the model. return tl.Serial( # Input: encoder_side_tokens, decoder_side_tokens # Copy decoder tokens for use in loss. tl.Select([0, 0, 0, 1, 1]), # tok_e tok_e tok_e tok_d tok_d # Embed in and out tokens; done together as weights may be shared. tl.Parallel( in_encoder, [], [], # vec_e tok_e tok_e vec_d tok_d [tl.ShiftRight(mode=mode), out_encoder]), tl.Parallel([], [ tl.PaddingMask(), tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1) ]), # # vec_e mask_e tok_e vec_d tok_d # Encode. encoder, # vec_e mask_e tok_e vec_d tok_d # Decode. tl.Select([3, 0, 1, 2]), # vec_d vec_e mask_e tok_e tok_d # Concat encoder and decoder, given encoder mask. tl.Select([1, 0]), # vec_e vec_d mask_e tok_e tok_d t2.ConcatWithPadding(mode=mode), # vec_ed tok_e tok_d # Run (encoder and) decoder blocks. tl.Dup(), # vec_ed1 vec_ed2 tok_e tok_d _ReversibleSerialForget( decoder_blocks, d_model, n_layers_forget), # vec_ed1 vec_ed2 tok_e tok_d tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), # vec_ed tok_e tok_d tl.LayerNorm(), # vec_ed tok_e tok_d # Separate out the encoder part from the concatenated vector. tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d t2.StripFromConcatenateWithPadding(mode=mode), # vec_d tok_d # Map to output vocab. tl.Dense(output_vocab_size), # vec_d tok_d tl.LogSoftmax(), # vec_d tok_d )
def test_two_no_ops(self):
  layer = tl.Parallel([], None)
  xs = [np.array([1, 2, 3]), np.array([10, 20])]
  ys = layer(xs)
  self.assertEqual(as_list(ys), [[1, 2, 3], [10, 20]])
def ReformerShortenLM(vocab_size, shorten_factor=1, d_embedding=256, d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64, n_layers=6, n_heads=8, dropout=0.1, max_len=2048, n_attention_chunks=1, attention_type=tl.DotProductCausalAttention, share_qk=False, axial_pos_shape=(), d_axial_pos_embs=None, ff_activation=tl.FastGelu, ff_use_sru=0, mode='train'): """Reversible transformer language model with shortening. When shorten_factor is F and processing an input of shape [batch, length], we embed the (shifted-right) input and then group each F elements (on length) into a single vector -- so that in the end we process a tensor of shape [batch, length // F, d_model] almost until the end -- at the end it's un-shortend and a SRU is applied. This reduces the length processed inside the main model body, effectively making the model faster but possibly slightly less accurate. Args: vocab_size: int: vocab size shorten_factor: by how much to shorten, see above d_embedding: the depth of the embedding layer and final logits d_model: int: depth of *each half* of the two-part features d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding n_attention_chunks: int: number of chunks for attention attention_type: class: attention class to use, such as DotProductAttention. share_qk: bool, whether to share queries and keys. axial_pos_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. d_axial_pos_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match axial_pos_shape, values must sum to d_embedding. ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward mode: str: 'train' or 'eval' Returns: the layer. """ if not axial_pos_shape: positional_encoding = tl.PositionalEncoding(max_len=max_len, dropout=dropout) else: assert d_axial_pos_embs is not None positional_encoding = tl.AxialPositionalEncoding( shape=axial_pos_shape, d_embs=d_axial_pos_embs, dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)), dropout=dropout) positional_embedder = [ tl.Embedding(d_embedding, vocab_size), BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter positional_encoding, ] decoder_blocks = [] if isinstance(attention_type, (tuple, list)): assert n_layers % len(attention_type) == 0 else: attention_type = [attention_type] for layer_idx in range(n_layers): layer_attention_type = attention_type[layer_idx % len(attention_type)] decoder_block = DecoderBlock( d_model, d_ff, d_attention_key, d_attention_value, n_heads, n_attention_chunks, attention_type=layer_attention_type, dropout=dropout, share_qk=(share_qk or issubclass(layer_attention_type, tl.LSHCausalAttention)), ff_activation=ff_activation, ff_use_sru=ff_use_sru, mode=mode) decoder_blocks.append(decoder_block) # pylint: disable=g-long-lambda return tl.Serial( tl.ShiftRight(), positional_embedder, tl.Dup(), # Stack has (x, x), the first will be shortened # Before shortening, we need to pad by shorten factor so as not to leak # information into the future. To understand why, imagine shorten factor # of 2 and sequence of length 4, so ABCD. 
If we shift just by 1, then we # would have 0ABC, which gets grouped to [0A][BC] on input, which is # predicting ABCD as targets. The problem is that [0A] has access to A # and [BC] has access to C -- it will learn to copy it, peek into # the future. Shifting twice to [00][AB] solves the problem as the first # "big" symbol becomes all-0 and the rest is shifted enough. tl.ShiftRight(n_shifts=shorten_factor - 1), tl.Fn( lambda x: np.reshape( # Shorten -- move to depth. x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1), tl.Dense(d_model), tl.Dup(), # Stack has (short_x, short_x, x) tl.ReversibleSerial(decoder_blocks), tl.Parallel([], tl.Drop()), tl.LayerNorm(), BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter tl.Dense(shorten_factor * d_embedding), tl.Fn( lambda x: np.reshape( # Prolong back. x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1), tl.Concatenate(), # Concatenate with just the embeddings. tl.CausalConv(d_embedding), tl.Relu(), tl.SRU(d_embedding), # One RNN layer for conditional dependence. tl.Dense(vocab_size), tl.LogSoftmax())
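# Illustrative numpy-only sketch of the shortening trick used above: with
# shorten_factor F, every F consecutive positions are folded into the depth
# axis, so the reversible body runs over length // F positions, and the same
# reshape in reverse restores the original length afterwards.
import numpy as np

F = 2
x = np.arange(1 * 4 * 3).reshape((1, 4, 3))                  # [batch, length, d_embedding]
short = np.reshape(x, (x.shape[0], x.shape[1] // F, -1))     # [1, 2, 6] -- "move to depth"
restored = np.reshape(short, (short.shape[0], short.shape[1] * F, -1))  # [1, 4, 3]
assert np.array_equal(x, restored)
# The extra tl.ShiftRight(n_shifts=F - 1) on top of the usual shift-by-one is
# what makes the first grouped "big" symbol all-zero, so no group can see the
# targets it is asked to predict.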
def test_default_name(self):
  layer = tl.Parallel(tl.Dup(), tl.Dup())
  self.assertIn('Parallel', str(layer))
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads, n_attention_chunks, attention_type, dropout, share_qk, ff_activation, ff_use_sru, ff_chunk_size, mode): """Reversible transformer decoder layer. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_heads: int: number of attention heads n_attention_chunks: int: number of chunks for attention attention_type: subclass of tl.BaseCausalAttention: attention class to use dropout: float: dropout rate (how much to drop out) share_qk: string, whether to share queries and keys ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks mode: str: 'train' or 'eval' Returns: the layer. """ if share_qk: pre_attention = [ Chunk(n_sections=n_attention_chunks), # pylint: disable=no-value-for-parameter tl.LayerNorm(), tl.Dup(), tl.Parallel( tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value), ), tl.Dup(), ] else: pre_attention = [ Chunk(n_sections=n_attention_chunks), # pylint: disable=no-value-for-parameter tl.LayerNorm(), tl.Dup(), tl.Dup(), tl.Parallel( tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key), tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value), ), ] attention = attention_type(mode=mode) # ReversibleAttentionHalfResidual requires that post_attention be linear in # its input (so the backward pass can be computed without knowing the input) post_attention = [ tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model), Unchunk(n_sections=n_attention_chunks), # pylint: disable=no-value-for-parameter BroadcastedDropout(rate=dropout, mode=mode), # pylint: disable=no-value-for-parameter ] if ff_use_sru: feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)] else: feed_forward = [ ChunkedFeedForward(d_model, d_ff, dropout, ff_activation, ff_chunk_size, mode) ] return [ ReversibleAttentionHalfResidual(pre_attention, attention, post_attention), tl.ReversibleSwap(), ReversibleHalfResidual(feed_forward), tl.ReversibleSwap(), ]
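# Conceptual numpy sketch (illustrative stand-in functions) of the reversible
# residual-plus-swap pattern this block is assembled from: each half-residual
# updates one half of the activations from the other half, so the inputs can
# be recomputed exactly during the backward pass instead of being stored.
import numpy as np

def attn(v):   # stand-in for the attention half
  return 2.0 * v

def ff(v):     # stand-in for the feed-forward half
  return v * v + 1.0

x1 = np.array([1.0, 2.0])
x2 = np.array([3.0, 4.0])
# Forward pass: two half-residuals with a swap in between.
y1 = x1 + attn(x2)
y2 = x2 + ff(y1)
# Inversion: recover the inputs from the outputs alone.
rec_x2 = y2 - ff(y1)
rec_x1 = y1 - attn(rec_x2)
assert np.allclose(rec_x1, x1) and np.allclose(rec_x2, x2)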
def test_custom_name(self):
  layer = tl.Parallel(tl.Dup(), tl.Dup(), name='DupDup')
  self.assertIn('DupDup', str(layer))
def _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation, pool_layer, pool_size, strides, separate_cls): """Internal funnel block. Returns a list of layers implementing it. The input is an activation tensor. Args: d_model: Final dimension of tensors at most points in the model, including the initial embedding output. d_ff: Size of special dense layer in the feed-forward part of each block. n_heads: Number of attention heads. dropout: Stochastic rate (probability) for dropping an activation value when applying dropout within a block. dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful way to save memory and apply consistent masks to activation vectors at different sequence positions. mode: If `'train'`, each block will include dropout; else, it will pass all values through unaltered. ff_activation: Type of activation function at the end of each block; must be an activation-type subclass of `Layer`. pool_layer: Type of pooling layer used for downsampling; should be `tl.AvgPool` or `tl.MaxPool`. pool_size: Shape of window that gets reduced to a single vector value. If the layer inputs are :math:`n`-dimensional arrays, then `pool_size` must be a tuple of length :math:`n-2`. strides: Offsets from the location of one window to the locations of neighboring windows along each axis. If specified, must be a tuple of the same length as `pool_size`. If None, then offsets of 1 along each window axis, :math:`(1, ..., 1)`, will be used. separate_cls: If `True`, pooling in funnel blocks is not applied to embeddings of the first token (`cls` from BERT paper). Returns: A list of layers that maps (activations, mask) to (activations', mask). """ pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls) mask_pooling = MaskPool(pool_size, strides, separate_cls) attention = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode) hidden_dropout = tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode) feed_forward = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation) return [ # h, mask tl.LayerNorm(), # h, mask tl.Branch(pooling, None), # h', h, mask tl.Residual( tl.Select([0, 1, 1, 2]), # h', h, h, mask attention, # attn, mask tl.Parallel(None, mask_pooling), # attn, mask' hidden_dropout # attn, mask' ), # funnel_activations, mask' tl.Residual( tl.LayerNorm(), feed_forward, hidden_dropout, ) ]
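# Rough numpy illustration (not the trax pooling layers themselves) of the
# downsampling a funnel block performs with pool_size=(2,) and strides=(2,),
# ignoring separate_cls: both the activations and the padding mask shrink
# along the length axis, which is why the block maps (h, mask) to (h', mask').
import numpy as np

h = np.arange(1 * 6 * 4, dtype=np.float32).reshape((1, 6, 4))  # [batch, length, d_model]
mask = np.array([[1, 1, 1, 1, 0, 0]])                          # [batch, length]
h_pooled = (h[:, 0::2, :] + h[:, 1::2, :]) / 2.0               # average pool -> [1, 3, 4]
mask_pooled = np.maximum(mask[:, 0::2], mask[:, 1::2])         # pooled mask  -> [1, 3]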
def test_weights(self):
  model = tl.Parallel(tl.Dense(3), tl.Dense(5))
  self.assertIsInstance(model.weights, tuple)
  self.assertLen(model.weights, 2)
def Transformer(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, mode='train', ff_activation=tl.Relu): """Returns a Transformer model. This model expects an input pair: target, source. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer Returns: A Transformer model as a layer that maps from a target, source pair to activations over a vocab set. """ def PositionalEmbedder(vocab_size): # tokens --> vectors return [ tl.Embedding(d_model, vocab_size), tl.Dropout(rate=dropout, mode=mode), tl.PositionalEncoding(max_len=max_len), ] def EncoderBlocks(n_blocks): # vectors masks --> vectors masks return [ _EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation) for i in range(n_blocks) ] def EncoderDecoderBlocks(n_blocks): # vectors masks --> vectors masks return [ _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation) for i in range(n_blocks) ] in_embed = PositionalEmbedder(input_vocab_size) out_embed = (in_embed if output_vocab_size is None else PositionalEmbedder(output_vocab_size)) if output_vocab_size is None: output_vocab_size = input_vocab_size # Assemble and return the model. return tl.Serial( # Input: encoder_side_tokens, decoder_side_tokens # Copy decoder tokens for use in loss. tl.Select([0, 1, 1]), # tok_e tok_d tok_d # Encode. tl.Branch(in_embed, tl.PaddingMask()), # vec_e masks ..... ..... EncoderBlocks(n_encoder_layers), # vec_d masks ..... ..... tl.LayerNorm(), # vec_e ..... ..... ..... # Decode. tl.Select([2, 1, 0]), # tok_d masks vec_e ..... tl.ShiftRight(), # tok_d ..... ..... ..... out_embed, # vec_d ..... ..... ..... tl.Branch([], tl.EncoderDecoderMask()), # vec_d masks ..... ..... EncoderDecoderBlocks(n_decoder_layers), # vec_d masks ..... ..... tl.LayerNorm(), # vec_d ..... ..... ..... # Map to output vocab. tl.Parallel([], tl.Drop(), tl.Drop()), # vec_d tok_d tl.Dense(output_vocab_size), # vec_d ..... tl.LogSoftmax(), # vec_d ..... )
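# Small usage sketch with made-up, deliberately tiny hyperparameters; the
# init/call pattern mirrors the tests in this file. The model takes an
# (encoder_tokens, decoder_tokens) pair and emits log-probabilities over the
# output vocabulary together with a copy of the decoder tokens for the loss.
import numpy as np
from trax import shapes

tiny_transformer = Transformer(input_vocab_size=128, d_model=32, d_ff=64,
                               n_encoder_layers=1, n_decoder_layers=1,
                               n_heads=2, max_len=16, mode='eval')
src = np.ones((2, 8), dtype=np.int32)
tgt = np.ones((2, 8), dtype=np.int32)
tiny_transformer.init(shapes.signature((src, tgt)))
log_probs, tgt_copy = tiny_transformer((src, tgt))   # log_probs: (2, 8, 128)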
def test_shared_weights_nested(self):
  layer = tl.Dense(5)
  model = tl.Parallel([layer, tl.Dense(2)], [layer, tl.Dense(2)])
  sample_input = (np.array([1, 2, 3, 4, 5]), np.array([1, 2, 3, 4, 5]))
  weights, _ = model.init(shapes.signature(sample_input))
  self.assertIs(weights[1][0], tl.GET_WEIGHTS_FROM_CACHE)
def LatentTransformer(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, dropout_shared_axes=None, max_len=2048, mode='train', ff_activation=tl.Relu, axial_pos_shape=None, d_axial_pos_embs=None): """Returns a Transformer model. This model expects an input pair: target, source. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) dropout_shared_axes: axes on which to share dropout mask max_len: int: maximum symbol length for positional encoding mode: str: 'train' or 'eval' ff_activation: the non-linearity in feed-forward layer axial_pos_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. d_axial_pos_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match axial_pos_shape, and values must sum to d_model. Returns: A Transformer model as a layer that maps from a target, source pair to activations over a vocab set. """ in_encoder, out_encoder, output_vocab_size = ( ct.EmbeddingAndPositionalEncodings(input_vocab_size, d_model, mode, dropout, dropout_shared_axes, max_len, output_vocab_size=output_vocab_size, axial_pos_shape=axial_pos_shape, d_axial_pos_embs=d_axial_pos_embs)) encoder_blocks = [ _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation) for i in range(n_encoder_layers) ] encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm()) if mode == 'predict': encoder = tl.Cache(encoder) decoder_blocks = [ _DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode, ff_activation) for i in range(n_decoder_layers) ] compress_seq = tl.Serial( # input: # tok tl.Branch([], tl.PaddingMask()), # tok mask encoder, # vec mask PickFirst(), # vec_f mask tl.Select([0], n_in=2)) # vec_f latent_transition = tl.Serial( tl.Parallel([tl.Dense(d_model), tl.Relu()], [tl.Dense(d_model), tl.Relu()]), tl.Add(), tl.Residual( tl.LayerNorm(), tl.Dense(d_model), tl.Relu(), tl.Dropout(rate=dropout, mode=mode), tl.Dense(d_model), )) pred_valid = tl.Serial(tl.Dense(2), Squeeze(1)) embed_tgt = tl.Serial( # Input # tok_d DropLast(mode=mode), # stok_d out_encoder, # svec_d ) decode_seq = tl.Serial( # Input: # vec_e tok_d tl.Select([1, 0, 1]), # tok_d vec_e tok_d tl.Parallel(embed_tgt, [], DropFirst()), # svec_d vec_e tok_d' ConcatDeEntoEnDe(), # vec_ed tok_d' # Decoder blocks with causal attention decoder_blocks, # vec_ed tok_d' tl.LayerNorm(), # vec_ed tok_d' DropFirst(), # vec_d tok_d' # Map to output vocab. tl.Dense(output_vocab_size), # pred_d tok_d' ) # compress_seq: n_in 1 n_out 1: add mask, encode, pick last hidden # latent_transition: n_in 2 n_out 1: s, a -> s_1 # pred_valid: n_in 1 n_out 1: s_1 -> pred_v # decode_seq: n_in 2 n_out 2: copy target, shift right, decode, output return tl.Serial( # 0 1 2 3 4 5 6 7 8 # Input: # tok_s tok_a tok_s1 r v tl.Select([0, 1, 2, 0, 1, 3, 4]), # tok_s tok_a tok_s1 tok_s tok_a r v # Encode. 
tl.Parallel( compress_seq, compress_seq), # vec_s vec_a tok_s1 tok_s tok_a r v tl.Branch(latent_transition, [], tl.Select( [1], n_in=2)), # vec_s1 vec_s vec_a tok_s1 tok_s tok_a r v tl.Branch(pred_valid, []), # pred_v vec_s1 vec_s vec_a tok_s1 tok_s tok_a r v # Decode. tl.Select([1, 4, 2, 5, 3, 6, 0, 8, 7]), # vec_s1 tok_s1 vec_s tok_s vec_a tok_a pred_v v r tl.Parallel(decode_seq, decode_seq, decode_seq ), # pred_s1 tok_s1 pred_s tok_s pred_a tok_a pred_v v r )
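# The wiring above leans heavily on tl.Select to copy and reorder items on
# the data stack. A tiny standalone illustration of its semantics (weightless
# layers can be called directly, as in the Parallel tests in this file):
import numpy as np
from trax import layers as tl

select = tl.Select([0, 1, 1])          # keep input 0, then input 1 twice
a, b = np.array([1]), np.array([2])
outs = select((a, b))                  # -> (array([1]), array([2]), array([2]))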
def EncoderBlock(d_model, d_ff, n_heads, attention_type, dropout, ff_activation, ff_dropout, ff_use_sru=0, ff_chunk_size=0, ff_sparsity=0, attention_chunk_size=0, center_layernorm=True, use_bfloat16=False, use_two_swaps_per_block=True, mode='train'): """Returns a list of layers that implements a Reformer encoder block. The input to the layer is a pair, (activations, mask), where the mask was created from the original source tokens to prevent attending to the padding part of the input. Args: d_model: int: depth of embedding d_ff: int: depth of feed-forward layer n_heads: int: number of attention heads attention_type: subclass of tl.BaseCausalAttention: attention class to use dropout: float: dropout rate (how much to drop out) ff_activation: the non-linearity in feed-forward layer ff_dropout: the dropout rate in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity attention_chunk_size: int, if > 0 run attention chunked at this size center_layernorm: whether to use centering in LayerNorm (default) or if to skip it, which is known as RMS normalization. use_bfloat16: whether to use bfloat16 for weights (default: False) use_two_swaps_per_block: bool, if True use two reversible swaps in Encoder block, otherwise use only one swap. mode: str: 'train' or 'eval' Returns: A list of layers that maps (activations, mask) to (activations, mask). """ if mode == 'predict': # Mode 'predict' means that the decoder should be run one token at a time. # The encoder only ever runs over full sequences, which is why it's switched # to 'eval' mode instead. mode = 'eval' def _Attn(): return ct.ApplyAttentionLayer( attention_type=attention_type, d_model=d_model, n_heads=n_heads, d_qk=d_model // n_heads, d_v=d_model // n_heads, masked=True, causal=False, attention_dropout=dropout, output_dropout=dropout, attention_chunk_size=attention_chunk_size, mode=mode) def _FF(): return ct.FeedForwardWithOptions(d_model, d_ff, dropout, [-2], ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, mode, use_bfloat16) # TODO(lukaszkaiser): refactor efficient attention layers to unify the API # If we're using standard attention, we need to pass reshaped mask and not # return the mask to be compatible with the EfficientAttention API. attention = _Attn() if attention.n_out == 2: attention = tl.Serial(tl.Parallel([], _InsertAxes12()), attention, tl.Select([0], n_in=2)) def _attention_half_residual(): return [ tl.ReversibleHalfResidual( tl.LayerNorm(center=center_layernorm), attention_layer=attention, name='ReversibleHalfResidualEncoderAttn'), tl.ReversibleSwap() ] def _feed_forward(): layers = [ tl.ReversibleHalfResidual(_FF(), name='ReversibleHalfResidualEncoderFF') ] if use_two_swaps_per_block: layers.append(tl.ReversibleSwap()) return layers return _attention_half_residual() + _feed_forward()
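# The center_layernorm flag above toggles mean-centering; with centering
# skipped the normalization is what the docstring calls RMS normalization.
# A plain-numpy sketch of the difference (gain/bias parameters omitted):
import numpy as np

def layer_norm(x, eps=1e-6):
  mu = x.mean(axis=-1, keepdims=True)
  var = ((x - mu) ** 2).mean(axis=-1, keepdims=True)
  return (x - mu) / np.sqrt(var + eps)

def rms_norm(x, eps=1e-6):   # no mean subtraction
  return x / np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)

x = np.array([[1.0, 2.0, 3.0, 4.0]])
print(layer_norm(x))   # zero-mean, unit-variance per feature vector
print(rms_norm(x))     # unit root-mean-square per feature vector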