def TransformerEncoder(vocab_size=vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu,
                       EncoderBlock=EncoderBlock):
  """Builds a Transformer encoder that classifies a tensor of tokens.

  The model embeds the tokens, runs them through `n_layers` encoder blocks,
  normalizes, averages over axis 1 (presumably the sequence axis) and maps the
  result to log-softmax activations over `n_classes`.

  Args:
    vocab_size (int): vocab size. Defaults to the module-level `vocab_size`.
    n_classes (int): number of output classes. Defaults to 10.
    d_model (int): depth of embedding. Defaults to 512.
    d_ff (int): depth of feed-forward layer. Defaults to 2048.
    n_layers (int): number of encoder blocks. Defaults to 6.
    n_heads (int): number of attention heads. Defaults to 8.
    dropout (float): dropout rate (how much to drop out). Defaults to 0.1.
    dropout_shared_axes: axes on which to share the dropout mask.
      Defaults to None.
    max_len (int): maximum symbol length for positional encoding.
      Defaults to 2048.
    mode (str): 'train' or 'eval'. Defaults to 'train'.
    ff_activation (function): the non-linearity in the feed-forward layer.
      Defaults to tl.Relu.
    EncoderBlock (function): factory returning one encoder block.
      Defaults to the module-level `EncoderBlock`.

  Returns:
    trax.layers.combinators.Serial: a layer mapping a tensor of tokens to
    activations over a set of output classes.
  """
  # Token ids -> vectors: embedding, dropout, then positional encoding.
  embedding_stage = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  # Stack `n_layers` identically configured encoder blocks.
  stacked_blocks = []
  for _ in range(n_layers):
    stacked_blocks.append(
        EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, FeedForwardBlock=FeedForwardBlock))

  # Run the embedding and the padding-mask computation side by side.
  embed_and_mask = tl.Branch(embedding_stage, tl.PaddingMask())
  # After the encoder blocks only the activations are needed; drop the mask.
  keep_activations = tl.Select([0], n_in=2)

  return tl.Serial(
      embed_and_mask,
      stacked_blocks,
      keep_activations,
      tl.LayerNorm(),
      tl.Mean(axis=1),       # Collapse axis 1 before classifying.
      tl.Dense(n_classes),   # Map to output categories.
      tl.LogSoftmax(),
  )
def ReformerNoEncDecAttention(input_vocab_size, output_vocab_size=None, d_model=512, d_ff=2048, d_attention_key=64, d_attention_value=64, n_encoder_layers=6, n_decoder_layers=6, n_heads=8, dropout=0.1, max_len=2048, encoder_attention_type=tl.SelfAttention, encoder_decoder_attention_type=tl.SelfAttention, axial_pos_shape=(), d_axial_pos_embs=None, ff_activation=tl.Relu, ff_use_sru=0, ff_chunk_size=0, ff_dropout=None, mode='train'): """Reversible transformer encoder-decoder model. This model expects an input pair: source, target. At the moment, this model supports dot-product attention only. For the attention types in the Reformer paper, see ReformerLM. Args: input_vocab_size: int: vocab size of the source. output_vocab_size: int (optional): vocab size of the target. If None, the source and target are assumed to have the same vocab. d_model: int: depth of embedding d_ff: int: depth of feed-forward layer d_attention_key: int: depth of key vector for each attention head d_attention_value: int: depth of value vector for each attention head n_encoder_layers: int: number of encoder layers n_decoder_layers: int: number of decoder layers n_heads: int: number of attention heads dropout: float: dropout rate (how much to drop out) max_len: int: maximum symbol length for positional encoding encoder_attention_type: class: attention class to use, such as SelfAttention encoder_decoder_attention_type: class: attention class to use, such as SelfAttention axial_pos_shape: tuple of ints: input shape to use for the axial position encoding. If unset, axial position encoding is disabled. d_axial_pos_embs: tuple of ints: depth of position embedding for each axis. Tuple length must match axial_pos_shape, and values must sum to d_model. 
ff_activation: the non-linearity in feed-forward layer ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks ff_dropout: float: (optional) separate dropout rate at feed-forward nonlinearity. This is called relu_dropout in T2T. mode: str: 'train' or 'eval' Returns: A Reformer model as a layer that maps from a target, source pair to activations over a vocab set. """ # The current API for custom gradients assumes that a layer must be # differentiable wrt all of its inputs, but the Transformer puts bool-dtype # masks on the stack. This causes jax to error, even though the so-called # "gradient" wrt the masks is never actually computed. # TODO(kitaev): remove this hack. if fastmath.backend_name() == 'jax': jax.api._check_inexact_input_vjp = lambda x: None # pylint: disable=protected-access def PositionalEncoder(vocab_size, mode): # tokens --> vectors if not axial_pos_shape: positional_encoding = tl.PositionalEncoding( max_len=max_len, dropout=dropout, mode=mode) else: assert d_axial_pos_embs is not None positional_encoding = tl.AxialPositionalEncoding( shape=axial_pos_shape, d_embs=d_axial_pos_embs, dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)), dropout=dropout, mode=mode) return [ tl.Embedding(vocab_size, d_model), tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode), positional_encoding, ] # TODO(kitaev): The regular trax Transformer shares vocab embeddings and # position embeddings between the encoder and decoder if output_vocab_size is # None. This isn't supported here because (a) Trax shares weights by sharing # layer instances, but we need two separate instances to have mode == 'eval' # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does # not work if its sublayers participate in any weight sharing. # Mode 'predict' means that the decoder should be run one token at a time. 
# The encoder only ever runs over full sequences, which is why it's switched # to 'eval' mode instead. in_encoder = PositionalEncoder( input_vocab_size, mode='eval' if mode == 'predict' else mode) if output_vocab_size is None: output_vocab_size = input_vocab_size out_encoder = PositionalEncoder(output_vocab_size, mode) # pylint: disable=g-complex-comprehension encoder_blocks = [ EncoderBlock( d_model, d_ff, n_heads, encoder_attention_type, dropout, ff_activation, ff_dropout, mode) for _ in range(n_encoder_layers)] # pylint: enable=g-complex-comprehension encoder = tl.Serial([ # tok_e mask_e tok_e tok_d tok_d in_encoder, # vec_e mask_e tok_e tok_d tok_d tl.Dup(), # vec_e1 vec_e2 mask_e tok_e tok_d tok_d tl.ReversibleSerial(encoder_blocks), tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), tl.LayerNorm(), ]) if mode == 'predict': encoder = tl.Cache(encoder) decoder_blocks = [] if isinstance(encoder_decoder_attention_type, (tuple, list)): assert n_decoder_layers % len(encoder_decoder_attention_type) == 0 else: encoder_decoder_attention_type = [encoder_decoder_attention_type] for layer_idx in range(n_decoder_layers): layer_attention_type = encoder_decoder_attention_type[ layer_idx % len(encoder_decoder_attention_type)] decoder_block = DecoderBlock( d_model, d_ff, d_attention_key, d_attention_value, n_heads, attention_type=layer_attention_type, dropout=dropout, ff_activation=ff_activation, ff_use_sru=ff_use_sru, ff_chunk_size=ff_chunk_size, mode=mode) decoder_blocks.append(decoder_block) # Assemble and return the model. return tl.Serial( # Input: encoder_side_tokens, decoder_side_tokens # Copy decoder tokens for use in loss. tl.Select([0, 0, 1, 1]), # tok_e tok_e tok_d tok_d tl.Branch([], [tl.PaddingMask(), tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]), # # tok_e mask_e tok_e tok_d tok_d # Encode. encoder, # vec_e mask_e tok_e tok_d tok_d # Decode. 
tl.Select([3, 0, 1, 2]), # tok_d vec_e mask_e tok_e tok_d tl.ShiftRight(mode=mode), # stok_d vec_e mask_e tok_e tok_d tl.Branch( [], _MaskOfRightShiftedArray() ), # stok_d mask_d vec_e mask_e tok_e tok_d out_encoder, # svec_d mask_d vec_e mask_e tok_e tok_d # Concat encoder and decoder, given their masks. tl.Select([2, 0, 3, 1]), # svec_d mask_d vec_e mask_e tok_e tok_d _ConcatWithPadding(), # vec_ed tok_e tok_d # Run (encoder and) decoder blocks. tl.Dup(), # vec_ed1 vec_ed2 tok_e tok_d tl.ReversibleSerial(decoder_blocks), # vec_ed1 vec_ed2 tok_e tok_d tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0), # vec_ed tok_e tok_d tl.LayerNorm(), # vec_ed tok_e tok_d # Separate out the encoder part from the concatenated vector. tl.Select([0, 1, 2, 2]), # vec_ed tok_e tok_d tok_d _StripFromConcatenateWithPadding(), # vec_d tok_d # Map to output vocab. tl.Dense(output_vocab_size), # vec_d tok_d tl.LogSoftmax(), # vec_d tok_d )
def test_noop_dup(self):
  """A Branch of a no-op and a Dup yields three copies of the input."""
  inputs = np.array([1, 2, 3])
  branch = tl.Branch([], tl.Dup())
  outputs = branch(inputs)
  # One copy from the empty (pass-through) branch, two from Dup.
  expected = [[1, 2, 3]] * 3
  self.assertEqual(as_list(outputs), expected)
def test_one_sublayer(self):
  """A Branch with a single sublayer returns that sublayer's output."""
  inputs = np.array([1, 2, 3])
  branch = tl.Branch(DivideBy(0.5))
  outputs = branch(inputs)
  # Dividing by 0.5 doubles every element.
  expected = [2, 4, 6]
  self.assertEqual(as_list(outputs), expected)
def test_printing_sublayers(self):
  """str() of a Branch shows its in/out counts and lists each sublayer."""
  branch = tl.Branch(tl.Add(), tl.Add())
  rendered = str(branch)
  self.assertEqual('Branch_in2_out2[\n Add_in2\n Add_in2\n]', rendered)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, id_to_mask=None, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):
  """Sets up models, optimizer, metrics and jitted functions, then resets.

  Args:
    model: callable invoked as `model(mode=...)` to build the 'train' and
      'eval' model instances (and a fresh 'train' instance for initialization).
    loss_fn: callable invoked as `loss_fn(has_weights=..., id_to_mask=...)`;
      the result is used as the loss layer.
    optimizer: callable invoked as `optimizer(learning_rate=0.0)`.
    lr_schedule: stored as-is; used outside this method.
    inputs: an Inputs instance, or a callable (e.g. configured through gin)
      that returns one when called with no arguments.
    output_dir: if falsy, checkpoint saving and summary writing are disabled.
      Passed to `self.reset`.
    random_seed: forwarded to `self._init_host_and_devices`.
    n_devices: forwarded to `self._init_host_and_devices`.
    checkpoints_at: stored as-is; an empty list is used when falsy.
    should_save_checkpoints: checkpoints are only saved when this is true, the
      host is chief and `output_dir` is set.
    should_write_summaries: summaries are only written when this is true and
      `output_dir` is set.
    has_weights: forwarded to the loss and metric constructors.
    nontrainable_param_map: stored as-is; an empty dict is used when None.
    id_to_mask: forwarded to the loss and metric constructors.
    metrics: dict from metric name to metric constructor; `_DEFAULT_METRICS`
      is used when None.
    checkpoint_highest: stored as-is; used outside this method.
    checkpoint_lowest: stored as-is; used outside this method.
  """
  self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
      n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  # Without an output directory there is nowhere to write to.
  if not output_dir:
    self._should_save_checkpoints = False
    self._should_write_summaries = False
  self._checkpoint_highest = checkpoint_highest
  self._checkpoint_lowest = checkpoint_lowest
  self._has_weights = has_weights
  self._id_to_mask = id_to_mask
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
  # Inputs is either an Inputs instance or a function that returns it.
  self._inputs = inputs
  if callable(
      inputs):  # If we pass a function, e.g., through gin, call it.
    self._inputs = inputs()
  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  # `init_rng` is reserved for initialization; the rest is split per device.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  # NOTE(review): `model_input_shape` and `model_target_shape` computed below
  # are not referenced again in this method; initialization further down uses
  # `self._inputs.example_shape_dtype` instead. Looks like leftover code --
  # confirm before removing.
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                             (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.input_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(self._inputs.input_shape))
  # Same for targets.
  if self._inputs.target_shape and isinstance(
      self._inputs.target_shape[0], (list, tuple)):
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in self._inputs.target_shape)
  else:
    model_target_shape = tuple([None] + list(self._inputs.target_shape))
  # Change all None to 1 in input and target shape.
  model_input_shape = math.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = math.nested_map(lambda x: x or 1, model_target_shape)

  def new_opt_state_and_model_state(shape_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    shapes, dtypes = shape_dtype
    input_signature = tuple(
        ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)
  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    new_opt_state_and_model_state = math.jit(
        new_opt_state_and_model_state, static_argnums=(0, ))
  # Zero-argument factory; it is only stored here and called elsewhere.
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          self._inputs.example_shape_dtype, init_rng))

  # Arrange and initialize metrics layers.
  # Metric names are sorted so the layer order is deterministic.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [
      self._metrics_dict[m](has_weights=self._has_weights,
                            id_to_mask=self._id_to_mask)
      for m in self._metrics
  ]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  example_signature = tuple(
      ShapeDtype(s, d) for (s, d) in zip(*self._inputs.example_shape_dtype))
  # The eval model is initialized so its output signature can be queried; the
  # metrics are then initialized against that signature.
  model_predict_eval.init(example_signature)
  output_signature = model_predict_eval.output_signature(
      example_signature)
  m_weights, m_state = metrics_in_parallel.init(output_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(model_predict_eval, metrics_in_parallel,
                                   self._n_devices)
  self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                       self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None
  self.reset(output_dir)
def TransformerNoEncDecAttention(input_vocab_size,
                                 output_vocab_size=None,
                                 d_model=512,
                                 d_ff=2048,
                                 n_encoder_layers=6,
                                 n_decoder_layers=6,
                                 n_heads=8,
                                 dropout=0.1,
                                 dropout_shared_axes=None,
                                 max_len=2048,
                                 mode='train',
                                 ff_activation=tl.Relu):
  """Returns a Transformer model without encoder-decoder attention.

  The model takes a pair of token tensors: encoder-side (source) tokens first,
  decoder-side (target) tokens second. The encoder output is concatenated with
  the embedded, right-shifted target and the decoder blocks run over the
  concatenation.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the decoder
      side reuses the encoder's embedding layer instances.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict'; in 'predict' mode the encoder is
      wrapped in tl.Cache.
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set (followed on the stack by a copy of the
    target tokens).
  """
  def _embed_tokens(vocab_size):
    """Returns the layers turning token ids into position-aware vectors."""
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  in_encoder = _embed_tokens(input_vocab_size)
  if output_vocab_size is None:
    # Shared vocab: reuse the very same embedding layers on the decoder side.
    out_encoder = in_encoder
    output_vocab_size = input_vocab_size
  else:
    out_encoder = _embed_tokens(output_vocab_size)

  # Encoder and decoder blocks take the same positional configuration.
  block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                ff_activation)

  encoder_blocks = [_EncoderBlock(*block_args)
                    for _ in range(n_encoder_layers)]
  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = [_DecoderBlock(*block_args)
                    for _ in range(n_decoder_layers)]

  # Trailing comments show the data stack after each layer (top first).
  pipeline = [
      # Input: encoder_side_tokens, decoder_side_tokens.
      # Duplicate decoder tokens so one copy survives for the loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d

      # Encoder pass.
      tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Replace the padding mask by a simple one without extra dims.
      tl.Select([2, 0, 2], n_in=3),     # tok_e vec_e tok_e tok_d tok_d
      _MaskOfRightShiftedArray(
          n_shifts=0),                  # mask_e vec_e tok_e tok_d tok_d

      # Decoder input preparation.
      tl.Select([3, 1, 0, 2]),          # tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray(),
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Join encoder and decoder activations into one sequence.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),             # vec_ed tok_e tok_d

      # Causal-attention decoder blocks over the joint sequence.
      decoder_blocks,                   # vec_ed tok_e tok_d
      tl.LayerNorm(),                   # vec_ed tok_e tok_d

      # Cut the encoder part back out of the joint sequence.
      tl.Select([0, 1, 2, 2]),              # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),   # vec_d tok_d

      # Project onto the output vocabulary.
      tl.Dense(output_vocab_size),      # vec_d tok_d
      tl.LogSoftmax(),                  # vec_d tok_d
  ]
  return tl.Serial(*pipeline)
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
  """Returns a Full Funnel Transformer, that can be used for example for BERT.

  This model produces token-level outputs over the whole vocab:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor of per-token activations over `vocab_size`
      categories; shape is (batch_size, sequence_length, vocab_size). The last
      layer is `tl.Dense`; no softmax is applied inside this model.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    encoder_segment_lengths: Non-empty tuple, where each element denotes the
      number of transformer encoder blocks preceding a funnel transformer
      block. There is no funnel block after the last sequence of encoder
      blocks, therefore the total number of encoder-side blocks is
      `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
      block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
      funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
      If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
      must be a tuple of length :math:`n-2`. It is also used as the pooling
      stride.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
      embeddings of the first token (`cls` from BERT paper); the flag is
      forwarded to the funnel blocks and to the upsampler.

  Returns:
    A `tl.Serial` layer mapping token IDs to per-token activations over the
    vocabulary.
  """
  assert encoder_segment_lengths

  def _encoder_block():
    """Builds one encoder block with this model's shared configuration."""
    return _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation)

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  n_encoder_segments = len(encoder_segment_lengths)

  # First segment: plain encoder blocks at full sequence resolution.
  blocks_at_full_resolution = [
      _encoder_block() for _ in range(encoder_segment_lengths[0])]

  # Every later segment is a funnel (pooling) block followed by its encoder
  # blocks.
  blocks_after_pooling = []
  for segment_length in encoder_segment_lengths[1:]:
    blocks_after_pooling.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout,
                     dropout_shared_axes, mode,
                     ff_activation, pool_layer,
                     pool_size=pool_size, strides=pool_size,
                     separate_cls=separate_cls))
    blocks_after_pooling.extend(
        _encoder_block() for _ in range(segment_length))

  decoder_blocks = [_encoder_block() for _ in range(n_decoder_blocks)]

  # One pooling step per funnel block, i.e. one fewer than the segment count.
  total_pool_size = pool_size[0] ** (n_encoder_segments - 1)

  embed_and_mask = tl.Branch(positional_encoder, tl.PaddingMask())
  # The residual from the first segment is taken before normalization, so it
  # is normalized here, right before upsampling.
  normalize_residual = tl.Parallel(None, tl.LayerNorm(), None)

  # Trailing comments show the data stack after each layer (top first).
  return tl.Serial(                               # toks
      embed_and_mask,                             # vecs masks
      blocks_at_full_resolution,                  # vecs masks
      tl.Select([0, 1, 0, 1]),                    # vecs masks residual=vecs old_masks
      blocks_after_pooling,                       # vecs masks residual masks
      tl.Select([0, 2, 3]),                       # vecs residual masks
      normalize_residual,                         # vecs norm(residual) masks
      _Upsampler(total_pool_size, separate_cls),  # vecs masks
      decoder_blocks,
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             axial_pos_shape=None,
             d_axial_pos_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: encoder-side (source) tokens first,
  decoder-side (target) tokens second.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'; in 'predict' mode the encoder is
      wrapped in tl.Cache.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set (followed on the stack by a copy of the
    target tokens).
  """
  # Feed-forward options common to encoder and encoder-decoder blocks.
  ff_options = dict(ff_use_sru=ff_use_sru,
                    ff_chunk_size=ff_chunk_size,
                    ff_sparsity=ff_sparsity)

  embeddings = ct.EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      [-2],  # dropout_shared_axes
      max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs)
  in_encoder, out_encoder, output_vocab_size = embeddings

  encoder_blocks = []
  for _ in range(n_encoder_layers):
    encoder_blocks.append(
        EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                     ff_activation, ff_dropout, mode=mode, **ff_options))

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      # Average the two reversible streams back into a single activation.
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = []
  for _ in range(n_decoder_layers):
    encoder_decoder_blocks.append(
        EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                            ff_dropout, mode, **ff_options))

  # Padding mask with its two singleton axes (1 and 2) squeezed away.
  squeezed_padding_mask = [
      tl.PaddingMask(),
      tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
  ]

  # Trailing comments show the data stack after each layer (top first).
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens.
      # Duplicate decoder tokens so one copy survives for the loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], squeezed_padding_mask), # tok_e mask  tok_d .....

      # Encoder pass.
      encoder,                              # vec_e  mask tok_d .....

      # Decoder pass.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Project onto the output vocabulary.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
def _FunnelBlock(d_model, d_ff, n_heads,
                 dropout, dropout_shared_axes, mode, ff_activation,
                 pool_layer, pool_size, strides, separate_cls):
  """Internal funnel block. Returns a list of layers implementing it.

  The block pools the normalized activations, attends from the pooled
  activations to the unpooled ones, pools the mask, and finishes with a
  residual feed-forward stage.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
      be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling;
      should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
      If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
      must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
      neighboring windows along each axis. If specified, must be a tuple of
      the same length as `pool_size`. If None, then offsets of 1 along each
      window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
      embeddings of the first token (`cls` from BERT paper).

  Returns:
    A list of layers that maps (activations, mask) to (activations', mask').
  """
  activation_pool = PoolLayer(pool_layer, pool_size, strides, separate_cls)
  mask_pool = MaskPool(pool_size, strides, separate_cls)

  qkv_attention = tl.AttentionQKV(
      d_model, n_heads=n_heads, dropout=dropout, mode=mode)

  # A single dropout layer instance, used in both residual stages below.
  hidden_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  # Trailing comments show the data stack after each layer (top first).
  normalize = tl.LayerNorm()                      # h, mask
  pool_and_keep = tl.Branch(activation_pool, None)  # h', h, mask

  attention_stage = tl.Residual(
      tl.Select([0, 1, 1, 2]),                    # h', h, h, mask
      qkv_attention,                              # attn, mask
      tl.Parallel(None, mask_pool),               # attn, mask'
      hidden_dropout,                             # attn, mask'
  )                                               # funnel_activations, mask'

  feed_forward_stage = tl.Residual(
      tl.LayerNorm(),
      feed_forward,
      hidden_dropout,
  )

  return [normalize, pool_and_keep, attention_stage, feed_forward_stage]
def FunnelTransformerEncoder(vocab_size,
                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             encoder_segment_lengths=(2, 2, 2),
                             n_heads=8,
                             max_len=2048,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             mode='train',
                             ff_activation=tl.Relu,
                             pool_layer=tl.AvgPool,
                             pool_size=(2,),
                             strides=(2,),
                             separate_cls=True):
  """Returns a Funnel Encoder.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 2 tensor of activations over N categories; shape is
      (batch_size, `n_classes`). The last layer is `tl.Dense`; no softmax is
      applied inside this model.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
      classification.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    encoder_segment_lengths: Non-empty tuple, where each element denotes the
      number of transformer encoder blocks preceding a funnel transformer
      block. There is no funnel block after the last sequence of encoder
      blocks, therefore the total number of blocks in the model is equal to
      `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
      Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
      a useful way to save memory and apply consistent masks to activation
      vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
      block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
      funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
      If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
      must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
      neighboring windows along each axis. If specified, must be a tuple of
      the same length as `pool_size`. If None, then offsets of 1 along each
      window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
      embeddings of the first token (`cls` from BERT paper) and only final
      embedding of this token is used for categorization - the rest are
      discarded. If `False`, each token from the beginning is pooled and
      all embeddings are averaged and mapped to output categories like in
      original `TransformerEncoder` model.

  Returns:
    A Transformer model that maps strings (conveyed via token IDs) to
    activations over a range of output classes.
  """
  assert encoder_segment_lengths

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  last_segment = len(encoder_segment_lengths) - 1
  encoder_blocks = []
  for segment_idx, segment_length in enumerate(encoder_segment_lengths):
    # The plain encoder blocks of this segment.
    encoder_blocks.extend(
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for _ in range(segment_length))

    # Segments are separated by funnel blocks; none follows the last one.
    if segment_idx == last_segment:
      continue
    encoder_blocks.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout,
                     dropout_shared_axes, mode,
                     ff_activation, pool_layer, pool_size,
                     strides, separate_cls))

  # Either keep only the first (`cls`) position or average over axis 1.
  if separate_cls:
    cls_pooling = SelectFirst()
  else:
    cls_pooling = tl.Mean(axis=1)

  embed_and_mask = tl.Branch(positional_encoder, tl.PaddingMask())

  # Trailing comments show the data stack after each layer (top first).
  return tl.Serial(                               # toks
      # Encoder pass.
      embed_and_mask,                             # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Reduce to one vector per example and classify.
      cls_pooling,                                # cls
      tl.Dense(n_classes),                        # cls
  )
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, has_weights=False,
             nontrainable_param_map=None, mask_id=None, metrics=None):
  """Builds the train/eval models, optimizer, metrics and jitted functions.

  Construction does not create weights: it stores a zero-argument callable
  (`self._new_opt_state_and_model_state`) that produces a fresh optimizer
  state and model state on demand. Training state fields are left as `None`
  unless `output_dir` is given, in which case `self.reset(output_dir)` is
  called at the end.

  Args:
    model: Callable invoked as `model(mode='train')` and `model(mode='eval')`
      to build the training and evaluation model instances; it is invoked
      again in train mode whenever fresh weights are initialized.
    loss_fn: Callable invoked as `loss_fn(has_weights=..., mask_id=...)`; the
      result is used as the loss layer placed after the model.
    optimizer: Callable invoked as `optimizer(learning_rate=0.0)`.
    lr_schedule: Stored unchanged on `self._lr_schedule`; not used further in
      this method.
    inputs: Callable invoked with the device count. The returned object must
      expose `input_shape`, `target_shape`, `input_dtype` and `target_dtype`.
    output_dir: If not `None`, passed to `self.reset` once setup is complete.
    random_seed: Forwarded to `self._init_host_and_devices`.
    n_devices: Forwarded to `self._init_host_and_devices`.
    checkpoints_at: Stored on `self._checkpoints_at`; a falsy value becomes
      an empty list.
    should_save_checkpoints: Only takes effect when this host is the chief
      (the flag is and-ed with `self._is_chief`).
    should_write_summaries: Stored on `self._should_write_summaries`.
    has_weights: If true, the initialization signature gets one extra
      float32 entry per target, and the metric layers are initialized with a
      three-entry dummy signature instead of a two-entry one.
    nontrainable_param_map: Stored on `self._nontrainable_param_map`; `None`
      becomes an empty dict.
    mask_id: Forwarded to the loss and to every metric constructor.
    metrics: Mapping from metric name to a callable accepting `has_weights`
      and `mask_id`; `None` selects `_DEFAULT_METRICS`.
  """

  self._is_chief, self._n_devices, rng = (
      self._init_host_and_devices(n_devices, random_seed))
  self._should_save_checkpoints = should_save_checkpoints and self._is_chief
  self._checkpoints_at = checkpoints_at or []
  self._should_write_summaries = should_write_summaries
  self._has_weights = has_weights
  self._mask_id = mask_id
  self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
  # From here on `loss_fn` and `inputs` refer to the constructed objects, not
  # to the factories that were passed in.
  loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
  inputs = inputs(self._n_devices)
  self._inputs = inputs

  # Initialize the learning rate to a dummy value. It will be set in reset().
  opt = optimizer(learning_rate=0.0)

  # Setup the model.
  model_train = model(mode='train')
  model_predict_eval = model(mode='eval')

  # Setup state.
  # One split yields `init_rng` (used for weight and metric initialization);
  # the remaining `rng` is split again into one key per device.
  rng, init_rng = jax_random.split(rng)
  self._rngs = np.stack(jax_random.split(rng, self._n_devices))
  first_shape = inputs.input_shape[0]
  # If the inputs are a tuple/list, add [None] (batch) to each element.
  if isinstance(first_shape, (list, tuple)):
    model_input_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.input_shape)
    model_target_shape = tuple(
        tuple([None] + list(shape)) for shape in inputs.target_shape)
  else:  # Otherwise just add [None] to the input shape.
    model_input_shape = tuple([None] + list(inputs.input_shape))
    model_target_shape = tuple([None] + list(inputs.target_shape))
  # Change all None to 1 in input and target shape.
  # NOTE(review): `x or 1` also maps a 0-sized dimension to 1 -- confirm that
  # this is intended.
  model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
  model_target_shape = backend.nested_map(lambda x: x or 1,
                                          model_target_shape)

  def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                    target_dtype, rng):
    """Returns optimizer and model states suitable for training a model."""
    # Combine inputs and targets on the stack.
    # A single (non-list) dtype means a single input/target, so both the
    # dtype and the matching shape are wrapped in one-element lists.
    if not isinstance(input_dtype, (list, tuple)):
      input_dtype = [input_dtype]
      input_shape = [input_shape]
    if not isinstance(target_dtype, (list, tuple)):
      target_dtype = [target_dtype]
      target_shape = [target_shape]
    dtypes = list(input_dtype) + list(target_dtype)
    shapes = list(input_shape) + list(target_shape)
    if self._has_weights:
      # Weights reuse the target shapes and are always float32.
      shapes += list(target_shape)
      dtypes += [np.float32 for _ in target_dtype]
    input_signature = tuple(ShapeDtype(s, d)
                            for (s, d) in zip(shapes, dtypes))
    # We need to create a new model instance and not reuse `model_train` here,
    # because `m.initialize` puts cached parameter values in `m` and hence the
    # next call of `m.initialize` will give wrong results.
    m = tl.Serial(model(mode='train'), loss_fn)
    m._set_rng_recursive(rng)  # pylint: disable=protected-access
    weights, state = m.init(input_signature)
    (slots, opt_params) = opt.tree_init(weights)
    return (OptState(weights, slots, opt_params), state)

  if _is_jit_init():
    # JIT parameter initialization to avoid memory fragmentation
    # Arguments 0-3 (shapes and dtypes) are marked static; only the rng
    # argument is left non-static.
    new_opt_state_and_model_state = backend.jit(new_opt_state_and_model_state,
                                                static_argnums=(0, 1, 2, 3))
  # Zero-argument callable; the dtypes are read from `self._inputs` at call
  # time, while the shapes and `init_rng` are captured from this scope.
  self._new_opt_state_and_model_state = (
      lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
          model_input_shape, self._inputs.input_dtype,
          model_target_shape, self._inputs.target_dtype, init_rng))

  # Arrange and initialize metrics layers.
  # Metric names are kept in sorted order; the layers below follow that order.
  self._metrics = list(sorted(self._metrics_dict.keys()))
  metrics_layers = [self._metrics_dict[m](has_weights=self._has_weights,
                                          mask_id=self._mask_id)
                    for m in self._metrics]
  metrics_in_parallel = tl.Branch(*metrics_layers)
  # TODO(lukaszkaiser): clean this up once layer API stabilizes.
  # For now, we need to initialize metric layers somehow, so here we go.
  # We assume that they do not have any parameters, so this is a dummy.
  dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
  dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
  metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
  m_weights, m_state = metrics_in_parallel.init(dummy_signature)
  self._metrics_weights = self._for_n_devices(m_weights)
  self._metrics_state = self._for_n_devices(m_state)

  # Jit model_predict and update so they're fast.
  self._jit_eval = _jit_predict_fn(
      model_predict_eval, metrics_in_parallel, self._n_devices)
  self._jit_update_fn = _jit_update_fn(
      model_train, loss_fn, opt, self._n_devices)

  self._model_train = model_train
  self._model_predict_eval = model_predict_eval
  self._loss_fn = loss_fn
  # TODO(pkozakowski): "Learning rate schedules" are currently able to
  # control all optimizer parameters and model state, so let's rename them
  # accordingly.
  self._lr_schedule = lr_schedule

  if nontrainable_param_map is None:
    nontrainable_param_map = {}
  self._nontrainable_param_map = nontrainable_param_map

  # Those fields will be set in reset().
  self._output_dir = None
  self._train_sw = None
  self._eval_sw = None
  self._history = None
  self._lr_fn = None
  self._opt_state = None
  self._step = None
  self._model_state = None

  if output_dir is not None:
    self.reset(output_dir)
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                max_len=2048,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.Relu):
  """Builds an encoder-decoder Transformer as one serial layer.

  The assembled layer consumes two inputs, in this order:

    - source (encoder-side) tokens: rank 2 tensor of token IDs with shape
      (batch_size, sequence_length); IDs lie in `range(input_vocab_size)` and
      `0` marks padding.
    - target (decoder-side) tokens: rank 2 tensor of token IDs with shape
      (batch_size, sequence_length); IDs lie in `range(output_vocab_size)`
      and `0` marks padding.

  It produces a rank 3 tensor of shape (batch_size, sequence_length,
  vocab_size) followed by a copy of the target tokens (kept for the loss).
  The last layer is a plain `tl.Dense`, so the activations are not
  normalized: no softmax or log-softmax is applied inside this model.

  A typical use is translating tokenized sentences between two languages.

  Args:
    input_vocab_size: Vocabulary size for the source tokens.
    output_vocab_size: Vocabulary size for the target tokens. If `None`, the
      source vocabulary size is used and the very same embedding and dropout
      layer objects are reused on the decoder side.
    d_model: Depth of the embeddings and of activations at most points in
      the model.
    d_ff: Depth of the wide dense layer in each block's feed-forward part.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of encoder-decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum length handled by the positional encodings.
    dropout: Dropout rate used after the embeddings and inside the blocks.
    dropout_shared_axes: Tensor axes on which a dropout mask is shared, e.g.
      `(0, 1)` to share along the batch and sequence axes.
    mode: `'train'`, `'eval'` or `'predict'`. In `'predict'` mode the encoder
      side is built in `'eval'` mode and wrapped in `tl.Cache`, while
      decoder-side layers receive `'predict'` unchanged.
    ff_activation: Activation layer type used in the feed-forward part of
      each block.

  Returns:
    A `tl.Serial` layer mapping a (source, target) token pair to activations
    over the target vocabulary.
  """
  shares_vocab = output_vocab_size is None
  target_vocab_size = input_vocab_size if shares_vocab else output_vocab_size

  def _token_embedder(vocab_size):
    """Returns the [embedding, dropout] pair turning tokens into vectors."""
    embedding = tl.Embedding(vocab_size, d_model)
    embedding_dropout = tl.Dropout(
        rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
    return [embedding, embedding_dropout]

  source_embedder = _token_embedder(input_vocab_size)
  if shares_vocab:
    target_embedder = source_embedder
  else:
    target_embedder = _token_embedder(output_vocab_size)

  # The encoder never runs step by step, so it is built in 'eval' mode when
  # the caller asks for 'predict'. Each side gets its own positional encoding.
  encoder_mode = mode
  if mode == 'predict':
    encoder_mode = 'eval'
  source_encoder = [
      *source_embedder,
      tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
  ]
  target_encoder = [
      *target_embedder,
      tl.PositionalEncoding(max_len=max_len, mode=mode),
  ]

  encoder_stack = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                    mode, ff_activation)
      for _ in range(n_encoder_layers)
  ]
  encoder = tl.Serial(source_encoder, encoder_stack, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_stack = [
      _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                           dropout_shared_axes, mode, ff_activation)
      for _ in range(n_decoder_layers)
  ]

  # Stack contents after each layer are listed top-first in the comments.
  # Model input: tok_e (encoder-side tokens), tok_d (decoder-side tokens).
  return tl.Serial(
      # Keep a spare copy of the decoder tokens for the loss.
      #   -> tok_e tok_d tok_d
      tl.Select([0, 1, 1]),

      # Encoder side.
      #   -> tok_e masks tok_d tok_d
      tl.Branch([], tl.PaddingMask()),
      #   -> vec_e masks tok_d tok_d
      encoder,

      # Decoder side.
      #   -> tok_d masks vec_e tok_d
      tl.Select([2, 1, 0]),
      tl.ShiftRight(mode=mode),
      #   -> vec_d masks vec_e tok_d
      target_encoder,
      tl.Branch([], tl.EncoderDecoderMask()),
      encoder_decoder_blocks_placeholder := decoder_stack,
      tl.LayerNorm(),

      # Project onto the target vocabulary.
      #   -> vec_d tok_d
      tl.Select([0], n_in=3),
      tl.Dense(target_vocab_size),
  ) if False else tl.Serial(
      # Keep a spare copy of the decoder tokens for the loss.
      #   -> tok_e tok_d tok_d
      tl.Select([0, 1, 1]),

      # Encoder side.
      #   -> tok_e masks tok_d tok_d
      tl.Branch([], tl.PaddingMask()),
      #   -> vec_e masks tok_d tok_d
      encoder,

      # Decoder side.
      #   -> tok_d masks vec_e tok_d
      tl.Select([2, 1, 0]),
      tl.ShiftRight(mode=mode),
      #   -> vec_d masks vec_e tok_d
      target_encoder,
      tl.Branch([], tl.EncoderDecoderMask()),
      decoder_stack,
      tl.LayerNorm(),

      # Project onto the target vocabulary.
      #   -> vec_d tok_d
      tl.Select([0], n_in=3),
      tl.Dense(target_vocab_size),
  )
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
  """Builds a Transformer encoder topped with an N-way classification head.

  The model is meant for text categorization:

    - input: rank 2 tensor of token IDs with shape (batch_size,
      sequence_length); IDs lie in `range(vocab_size)` and `0` marks padding.
    - output: rank 2 tensor of shape (batch_size, `n_classes`). The last
      layer is a plain `tl.Dense`, so the activations are not normalized:
      no softmax or log-softmax is applied inside this model.

  Args:
    vocab_size: Size of the input vocabulary.
    n_classes: Number of output categories (depth of the final dense layer).
    d_model: Depth of the embeddings and of activations at most points in
      the model.
    d_ff: Depth of the wide dense layer in each block's feed-forward part.
    n_layers: Number of encoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum length handled by the positional encoding.
    dropout: Dropout rate used after the embedding and inside the blocks.
    dropout_shared_axes: Tensor axes on which a dropout mask is shared, e.g.
      `(0, 1)` to share along the batch and sequence axes.
    mode: Passed to the dropout layer and to every encoder block; `'train'`
      enables dropout.
    ff_activation: Activation layer type used in the feed-forward part of
      each block.

  Returns:
    A `tl.Serial` layer mapping token IDs to per-class activations.
  """
  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  def _new_block():
    """Returns one freshly constructed encoder block."""
    return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                         dropout_shared_axes, mode, ff_activation)

  block_stack = [_new_block() for _ in range(n_layers)]

  # Stack contents after each layer are listed top-first in the comments.
  # Model input: toks.
  return tl.Serial(
      # Embed the tokens and derive the padding mask from them.
      #   -> vecs masks
      tl.Branch(token_encoder, tl.PaddingMask()),
      block_stack,
      # The mask is no longer needed once the blocks have run.
      #   -> vecs
      tl.Select([0], n_in=2),
      tl.LayerNorm(),

      # Classification head: average over axis 1, then project to classes.
      tl.Mean(axis=1),
      tl.Dense(n_classes),
  )
def ConfigurableTransformerEncoder(vocab_size,
                                   n_classes=10,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   max_len=2048,
                                   dropout=0.1,
                                   dropout_shared_axes=None,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   ff_dropout=0.1,
                                   ff_chunk_size=0,
                                   ff_use_sru=0,
                                   ff_sparsity=0,
                                   ff_sparsity_type='1inN',
                                   attention_chunk_size=0,
                                   attention_type=tl.Attention,
                                   axial_pos_shape=None,
                                   d_axial_pos_embs=None):
  """Builds a configurable Transformer encoder with a classification head.

  The model is meant for text categorization:

    - input: rank 2 tensor of token IDs with shape (batch_size,
      sequence_length); IDs lie in `range(vocab_size)` and `0` marks padding.
    - output: rank 2 tensor of shape (batch_size, `n_classes`). The last
      layer is a plain `tl.Dense`, so the activations are not normalized:
      no softmax or log-softmax is applied inside this model.

  Args:
    vocab_size: Size of the input vocabulary.
    n_classes: Number of output categories (depth of the final dense layer).
    d_model: Depth of the embeddings and of activations at most points in
      the model.
    d_ff: Depth of the wide dense layer in each block's feed-forward part.
    n_layers: Number of encoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum length handled by the positional encoding.
    dropout: Dropout rate used after the embedding, by the positional
      encoder and inside the blocks.
    dropout_shared_axes: Tensor axes on which a dropout mask is shared, e.g.
      `(0, 1)` to share along the batch and sequence axes.
    mode: Passed to the dropout layer, the positional encoder and every
      encoder block; `'train'` enables dropout.
    ff_activation: Activation layer type used in the feed-forward part of
      each block.
    ff_dropout: Dropout rate applied after the feed-forward dense layer.
    ff_chunk_size: int; if > 0, the feed-forward part is chunked into pieces
      of this size.
    ff_use_sru: int or pair of ints; if > 0, that many SRU layers are used in
      addition to the feed-forward block (a second int gives the SRU size).
    ff_sparsity: int; if > 0, a sparse feed-forward block with this sparsity
      is used.
    ff_sparsity_type: string; when `ff_sparsity` > 0, `'1inN'` selects
      SparseFF and `'Block'` selects BlockSparseFF.
    attention_chunk_size: int; if > 0, attention runs chunked at this size.
    attention_type: Attention layer used by the encoder blocks.
    axial_pos_shape: tuple of ints giving the input shape for axial position
      encoding; if unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints giving the position-embedding depth per
      axis; its length must match `axial_pos_shape` and its values must sum
      to `d_model`.

  Returns:
    A `tl.Serial` layer mapping token IDs to per-class activations.
  """
  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      PositionalEncoder(
          mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs),
  ]

  def _new_block():
    """Returns one freshly constructed, fully configured encoder block."""
    return EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, attention_type)

  block_stack = [_new_block() for _ in range(n_layers)]

  # Stack contents after each layer are listed top-first in the comments.
  # Model input: toks.
  return tl.Serial(
      # Embed the tokens and derive the padding mask from them.
      #   -> vecs masks
      tl.Branch(token_encoder, tl.PaddingMask()),
      block_stack,
      # The mask is no longer needed once the blocks have run.
      #   -> vecs
      tl.Select([0], n_in=2),
      tl.LayerNorm(),

      # Classification head: average over axis 1, then project to classes.
      tl.Mean(axis=1),
      tl.Dense(n_classes),
  )
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             mode='train'):
  """Builds a reversible Transformer encoder-decoder.

  The assembled layer consumes two inputs: encoder-side (source) tokens
  first, decoder-side (target) tokens second. The first input is the one
  embedded with the source embedder below.

  Only dot-product attention is supported here; for the attention types from
  the Reformer paper, see ReformerLM.

  Calling this function also overwrites `jax.api._check_inexact_input_vjp`
  with a no-op, which is a process-wide side effect.

  Args:
    input_vocab_size: int: vocabulary size of the source.
    output_vocab_size: int (optional): vocabulary size of the target. If
      `None`, the source vocabulary size is used and the very same
      embedding, dropout and positional-encoding layer objects are reused on
      the decoder side.
    d_model: int: depth of the embedding.
    d_ff: int: depth of the feed-forward layer.
    n_encoder_layers: int: number of encoder blocks.
    n_decoder_layers: int: number of encoder-decoder blocks.
    n_heads: int: number of attention heads.
    dropout: float: dropout rate.
    max_len: int: maximum length handled by the positional encoding.
    ff_activation: the non-linearity used in the feed-forward layers.
    mode: str: `'train'` or `'eval'`.

  Returns:
    A `tl.Serial` layer mapping a (source, target) token pair to
    log-softmax activations over the target vocabulary.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  shares_vocab = output_vocab_size is None
  target_vocab_size = input_vocab_size if shares_vocab else output_vocab_size

  def _token_encoder(vocab_size):
    """Returns [embedding, dropout, positional encoding] for one side."""
    # TODO(kitaev): axial positional encoding is better for very long
    # sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    # NOTE(review): the (d_model, vocab_size) argument order differs from
    # other `tl.Embedding` calls in this file -- confirm it matches the
    # `tl.Embedding` signature of the library version this function targets.
    embedding = tl.Embedding(d_model, vocab_size)
    # TODO(kitaev): BroadcastedDropout?
    embedding_dropout = tl.Dropout(rate=dropout, mode=mode)
    return [embedding, embedding_dropout, position_layer]

  source_encoder = _token_encoder(input_vocab_size)
  if shares_vocab:
    target_encoder = source_encoder
  else:
    target_encoder = _token_encoder(output_vocab_size)

  encoder_stack = [
      EncoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_encoder_layers)
  ]
  decoder_stack = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                          mode)
      for _ in range(n_decoder_layers)
  ]

  # Stack contents after each layer are listed top-first in the comments.
  # Model input: tok_e (encoder-side tokens), tok_d (decoder-side tokens).
  return tl.Serial(
      # Keep a spare copy of the decoder tokens for the loss.
      #   -> tok_e tok_d tok_d
      tl.Select([0, 1, 1]),

      # Encoder side: embed the tokens and build a padding mask with its
      # axes 1 and 2 squeezed away.
      #   -> vec_e mask tok_d tok_d
      tl.Branch(
          source_encoder,
          [tl.PaddingMask(),
           tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)],
      ),
      # Reversible blocks operate on two activation streams.
      #   -> vec_e1 vec_e2 mask tok_d tok_d
      tl.Dup(),
      tl.ReversibleSerial(encoder_stack),
      # The two sets of activations need to be reduced to one, in this case
      # by averaging them. Note that ReformerLM concatenates instead. Various
      # options (concat, average, add, keep only one, etc.) seem to perform
      # similarly. We don't concatenate here because we want exact parameter
      # parity with the standard Transformer.
      #   -> vec_e mask tok_d tok_d
      tl.Fn(lambda x, y: (x+y)/2.0),
      tl.LayerNorm(),

      # Decoder side.
      #   -> tok_d vec_e mask tok_d
      tl.Select([2, 0, 1]),
      tl.ShiftRight(),
      #   -> vec_d vec_e mask tok_d
      target_encoder,
      #   -> vec_d1 vec_d2 vec_e mask tok_d
      tl.Dup(),
      tl.ReversibleSerial(decoder_stack),
      #   -> vec_d vec_e mask tok_d
      tl.Fn(lambda x, y: (x+y)/2.0),
      tl.LayerNorm(),

      # Project onto the target vocabulary.
      #   -> vec_d tok_d
      tl.Select([0], n_in=3),
      tl.Dense(target_vocab_size),
      tl.LogSoftmax(),
  )
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            axial_pos_shape=None,
                            d_axial_pos_embs=None):
  """Builds a configurable encoder-decoder Transformer as one serial layer.

  The assembled layer consumes two inputs, in this order:

    - source (encoder-side) tokens: rank 2 tensor of token IDs with shape
      (batch_size, sequence_length); IDs lie in `range(input_vocab_size)` and
      `0` marks padding.
    - target (decoder-side) tokens: rank 2 tensor of token IDs with shape
      (batch_size, sequence_length); IDs lie in `range(output_vocab_size)`
      and `0` marks padding.

  It produces a rank 3 tensor of shape (batch_size, sequence_length,
  vocab_size) followed by a copy of the target tokens (kept for the loss).
  The last layer is a plain `tl.Dense`, so the activations are not
  normalized: no softmax or log-softmax is applied inside this model.

  A typical use is translating tokenized sentences between two languages.

  Args:
    input_vocab_size: Vocabulary size for the source tokens.
    output_vocab_size: Vocabulary size for the target tokens; passed to
      `EmbeddingAndPositionalEncodings`, whose third return value is the
      size actually used for the output projection.
    d_model: Depth of the embeddings and of activations at most points in
      the model.
    d_ff: Depth of the wide dense layer in each block's feed-forward part.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of encoder-decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum length handled by the positional encodings.
    dropout: Dropout rate used by the embeddings and inside the blocks.
    dropout_shared_axes: Tensor axes on which a dropout mask is shared, e.g.
      `(0, 1)` to share along the batch and sequence axes.
    mode: `'train'`, `'eval'` or `'predict'`. In `'predict'` mode the encoder
      is additionally wrapped in `tl.Cache`.
    ff_activation: Activation layer type used in the feed-forward part of
      each block.
    ff_dropout: Dropout rate applied after the feed-forward dense layer.
    ff_chunk_size: int; if > 0, the feed-forward part is chunked into pieces
      of this size.
    ff_use_sru: int or pair of ints; if > 0, that many SRU layers are used in
      addition to the feed-forward block (a second int gives the SRU size).
    ff_sparsity: int; if > 0, a sparse feed-forward block with this sparsity
      is used.
    ff_sparsity_type: string; when `ff_sparsity` > 0, `'1inN'` selects
      SparseFF and `'Block'` selects BlockSparseFF.
    attention_chunk_size: int; if > 0, attention runs chunked at this size.
    encoder_attention_type: Attention layer used by the encoder blocks.
    encoder_decoder_attention_type: Attention layer used by the
      encoder-decoder blocks.
    axial_pos_shape: tuple of ints giving the input shape for axial position
      encoding; if unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints giving the position-embedding depth per
      axis; its length must match `axial_pos_shape` and its values must sum
      to `d_model`.

  Returns:
    A `tl.Serial` layer mapping a (source, target) token pair to activations
    over the target vocabulary.
  """
  encodings = EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      dropout_shared_axes,
      max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs)
  source_encoder, target_encoder, target_vocab_size = encodings

  # Feed-forward options shared verbatim by encoder and decoder blocks.
  block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
                ff_sparsity, ff_sparsity_type, attention_chunk_size)

  encoder_stack = [
      EncoderBlock(*block_args, encoder_attention_type)
      for _ in range(n_encoder_layers)
  ]
  encoder = tl.Serial(source_encoder, encoder_stack, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_stack = [
      EncoderDecoderBlock(*block_args, encoder_decoder_attention_type)
      for _ in range(n_decoder_layers)
  ]

  # Stack contents after each layer are listed top-first in the comments.
  # Model input: tok_e (encoder-side tokens), tok_d (decoder-side tokens).
  return tl.Serial(
      # Keep a spare copy of the decoder tokens for the loss.
      #   -> tok_e tok_d tok_d
      tl.Select([0, 1, 1]),

      # Encoder side.
      #   -> tok_e masks tok_d tok_d
      tl.Branch([], tl.PaddingMask()),
      #   -> vec_e masks tok_d tok_d
      encoder,

      # Decoder side.
      #   -> tok_d masks vec_e tok_d
      tl.Select([2, 1, 0]),
      tl.ShiftRight(mode=mode),
      #   -> vec_d masks vec_e tok_d
      target_encoder,
      tl.Branch([], tl.EncoderDecoderMask()),
      decoder_stack,
      tl.LayerNorm(),

      # Project onto the target vocabulary.
      #   -> vec_d tok_d
      tl.Select([0], n_in=3),
      tl.Dense(target_vocab_size),
  )
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Builds an autoregressive RNN language model.

    - input: rank 2 tensor of token IDs with shape (batch_size,
      sequence_length); IDs lie in `range(vocab_size)` and `0` marks padding.
    - output: rank 3 tensor of shape (batch_size, sequence_length,
      `vocab_size`). The last layer is a plain `tl.Dense`, so the
      activations are not normalized: no softmax or log-softmax is applied
      inside this model.

  Args:
    vocab_size: Size of the vocabulary; used both for the embedding and for
      the output projection.
    d_model: Depth of the embedding and number of units in each RNN cell.
    n_layers: Number of stacked RNN cells. Only the value 2 is accepted.
    rnn_cell: RNN cell constructor, called as `rnn_cell(n_units=d_model)`.
    rnn_cell_d_state_multiplier: Multiplier for the depth of one cell's
      state; the initial zero state uses `n_layers` times this value as its
      depth multiplier.
    dropout: Dropout rate applied after the embedding.
    mode: Passed to the shift, dropout and scan layers; `'train'` enables
      dropout and `'predict'` selects fast inference.

  Returns:
    A `tl.Serial` layer mapping token IDs to activations over the
    vocabulary.

  Raises:
    ValueError: If `n_layers` is not 2.
  """
  # TODO(jonni): Remove n_layers arg, if it can't vary?
  if n_layers != 2:
    raise ValueError(f'Number of layers must be set to 2; instead got'
                     f' {n_layers}.')

  def _stacked_cell():
    """Returns one layer that runs all RNN cells in sequence."""
    cells = [rnn_cell(n_units=d_model) for _ in range(n_layers)]
    return tl.Serial(
        # Split the combined state into one piece per cell ...
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(cells),
        # ... and merge the per-cell states back into one.
        tl.Parallel([], tl.Concatenate(n_items=n_layers)),
    )

  state_depth_multiplier = n_layers * rnn_cell_d_state_multiplier
  initial_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=state_depth_multiplier)

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      # Put an all-zero RNN state on the stack next to the embeddings.
      tl.Branch([], initial_state),
      tl.Scan(_stacked_cell(), axis=1, mode=mode),
      # Keep the per-step outputs and drop the final RNN state.
      tl.Select([0], n_in=2),
      tl.Dense(vocab_size),
  )
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
  """Builds a Transformer-style encoder with a classification head.

  For each item in a batch:

    - input: sequence of integers, usually token IDs from a fixed-size
      vocabulary, i.e. integers in `range(vocab_size)`.
    - output: the encoder activations are averaged over axis 1, projected to
      `n_classes` values by a dense layer and passed through a log-softmax,
      so the result can be read as log-probabilities over the classes.

  Args:
    vocab_size: Size of the input vocabulary; input IDs must lie in
      `range(vocab_size)`.
    n_classes: Depth of the output, intended for an N-way classification
      task.
    d_model: Basic embedding depth, used by the embedding layer and at many
      intermediate points in the model.
    d_ff: Depth (typically larger than `d_model`) of the feed-forward dense
      layer in each encoder block.
    n_layers: Number of encoder blocks.
    n_heads: Number of attention heads.
    dropout: Dropout rate used after the embedding and inside the blocks.
    dropout_shared_axes: Tensor axes on which a dropout mask is shared.
    max_len: Maximum length handled by the positional encoding.
    mode: Passed to the dropout layer and to every encoder block; `'train'`
      enables dropout.
    ff_activation: Activation layer type used in the feed-forward part of
      each block.

  Returns:
    A `tl.Serial` layer mapping token IDs to log-softmax activations over
    the output classes.
  """
  token_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len),
  ]

  def _new_block():
    """Returns one freshly constructed encoder block."""
    return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                         dropout_shared_axes, mode, ff_activation)

  block_stack = [_new_block() for _ in range(n_layers)]

  # Stack contents after each layer are listed top-first in the comments.
  # Model input: toks.
  return tl.Serial(
      # Embed the tokens and derive the padding mask from them.
      #   -> vecs masks
      tl.Branch(token_encoder, tl.PaddingMask()),
      block_stack,
      # The mask is no longer needed once the blocks have run.
      #   -> vecs
      tl.Select([0], n_in=2),
      tl.LayerNorm(),

      # Classification head: average over axis 1, project to classes, then
      # normalize with log-softmax.
      tl.Mean(axis=1),
      tl.Dense(n_classes),
      tl.LogSoftmax(),
  )
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=D_MODEL,
                d_ff=D_FF,
                n_encoder_layers=N_LAYERS,
                n_decoder_layers=N_LAYERS,
                n_heads=N_HEADS,
                max_len=MAX_SEQUENCE_LENGTH,
                dropout=DROPOUT_RATE,
                dropout_shared_axes=DROPOUT_SHARED_AXES,
                mode=MODE,
                ff_activation=FF_ACTIVATION_TYPE):
    """Builds a full encoder-decoder Transformer.

    The model performs tokenized source-to-target transduction:

    - inputs (2): source token IDs followed by target token IDs, each of
      shape (batch_size, sequence_length) with sequence_length <=
      ``max_len``; 0 values mark padding positions.
    - output: raw activations whose last dimension is ``output_vocab_size``,
      followed on the stack by a copy of the target tokens (kept for use in
      the loss).

    Args:
      input_vocab_size: Source vocabulary size; source IDs should lie in
        ``range(input_vocab_size)``.
      output_vocab_size: Target vocabulary size. If ``None``, the target
        reuses the source vocabulary size and the very same embedding layer
        instance, so embedding weights are shared.
      d_model: Last dimension of activations at most points in the model,
        including the embedding output.
      d_ff: Last dimension of the wider :py:class:`Dense` layer in the
        feedforward part of each block.
      n_encoder_layers: Number of encoder blocks.
      n_decoder_layers: Number of decoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Dropout rate for the embedding dropout layers; also forwarded
        to every encoder/decoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
      mode: ``'train'``, ``'eval'`` or ``'predict'``. In ``'predict'`` the
        encoder's positional encoding runs in ``'eval'`` mode and the encoder
        is wrapped in ``tl.Cache``.
      ff_activation: Activation layer type used in each block.

    Returns:
      A ``tl.Serial`` layer mapping a (source, target) token pair to
      activations over the output vocabulary.
    """
    predicting = mode == 'predict'
    # The encoder never runs stepwise, so it must not see 'predict' mode.
    encoder_mode = 'eval' if predicting else mode

    # One embedding instance serves both sides when vocabularies coincide.
    source_embedder = tl.Embedding(input_vocab_size, d_model)
    if output_vocab_size is None:
        target_embedder = source_embedder
        output_vocab_size = input_vocab_size
    else:
        target_embedder = tl.Embedding(output_vocab_size, d_model)

    dropout_kwargs = dict(
        rate=dropout, shared_axes=dropout_shared_axes, mode=mode)
    block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation)

    # Model input is encoder-side tokens then decoder-side tokens: tok_e tok_d
    # Model output is decoder-side vectors and decoder-side tokens: vec_d tok_d
    layers = [
        tl.Select([0, 1, 1]),             # Copies decoder tokens for loss.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks tok_d tok_d
    ]

    # Encode.
    encoder = tl.Serial(
        source_embedder,
        tl.Dropout(**dropout_kwargs),
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
        [_EncoderBlock(*block_args) for _ in range(n_encoder_layers)],
        tl.LayerNorm(),
    )
    if predicting:
        encoder = tl.Cache(encoder)
    layers.append(encoder)

    # Decode.
    layers.extend([
        tl.Select([2, 1, 0]),             # Re-orders: tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),
        target_embedder,
        tl.Dropout(**dropout_kwargs),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        [_EncoderDecoderBlock(*block_args) for _ in range(n_decoder_layers)],
        tl.LayerNorm(),
        tl.Select([0], n_in=3),           # Drops masks and encoding vectors.
        # Map vectors to match output vocab size.
        tl.Dense(output_vocab_size),
    ])
    return tl.Serial(*layers)
def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
             output_dir=None, random_seed=None, n_devices=None,
             checkpoints_at=None, should_save_checkpoints=True,
             should_write_summaries=True, metrics=None,
             checkpoint_highest=None, checkpoint_lowest=None):
    """Sets up models, optimizer, RNGs and metrics, then calls `reset`.

    Args:
      model: Callable accepting a `mode` keyword; it is invoked once with
        `mode='train'` and once with `mode='eval'`.
      loss_fn: Loss layer; serially combined after the training model.
      optimizer: Callable accepting a `learning_rate` keyword and returning
        the optimizer instance.
      lr_schedule: Learning-rate schedule; stored for later use.
      inputs: An inputs object exposing `example_shape_dtype`, or a callable
        returning one.
      output_dir: Forwarded to `reset`. If falsy, checkpoint saving and
        summary writing are both disabled.
      random_seed: Forwarded to `training.init_host_and_devices`.
      n_devices: Forwarded to `training.init_host_and_devices`.
      checkpoints_at: Stored as given; a falsy value becomes an empty list.
      should_save_checkpoints: Checkpoints are saved only if this is truthy,
        this host is the chief, and `output_dir` is set.
      should_write_summaries: Summaries are written only if this is truthy
        and `output_dir` is set.
      metrics: Mapping from metric name to metric layer; `None` selects
        `_DEFAULT_METRICS`.
      checkpoint_highest: Stored as given.
      checkpoint_lowest: Stored as given.
    """
    host_setup = training.init_host_and_devices(n_devices, random_seed)
    self._is_chief, _, self._n_devices, rng = host_setup

    # With no output directory, checkpointing and summaries are switched off.
    has_output_dir = bool(output_dir)
    self._should_save_checkpoints = (
        (should_save_checkpoints and self._is_chief) if has_output_dir
        else False)
    self._should_write_summaries = (
        should_write_summaries if has_output_dir else False)
    self._checkpoints_at = checkpoints_at or []
    self._checkpoint_highest = checkpoint_highest
    self._checkpoint_lowest = checkpoint_lowest
    self._metrics_dict = _DEFAULT_METRICS if metrics is None else metrics

    # `inputs` may be the inputs object itself or a factory for it (e.g.
    # when configured through gin); call it in the latter case.
    self._inputs = inputs() if callable(inputs) else inputs

    # The learning rate starts at a dummy value; it will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Set up the models.
    train_model = model(mode='train')
    eval_model = model(mode='eval')
    self._model_with_loss = tl.Serial(train_model, loss_fn)

    # Set up RNG state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    shapes, dtypes = self._inputs.example_shape_dtype
    input_signature = tuple(
        ShapeDtype(shape, dtype) for shape, dtype in zip(shapes, dtypes))

    def new_opt_state_and_model_state(rng):
        """Returns optimizer and model states suitable for training a model."""
        weights, state = self._model_with_loss.init(input_signature, rng=rng)
        slots, opt_params = opt.tree_init(weights)
        return OptState(weights, slots, opt_params), state

    if fastmath.is_backend(fastmath.Backend.JAX):
        # JIT parameter initialization to avoid memory fragmentation.
        new_opt_state_and_model_state = fastmath.jit(
            new_opt_state_and_model_state)
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(init_rng))

    # Arrange metric layers in name order and initialize them.
    self._metrics = sorted(self._metrics_dict.keys())
    metrics_branch = tl.Branch(
        *[self._metrics_dict[name] for name in self._metrics])
    metrics_branch.rng = init_rng
    example_signature = tuple(
        ShapeDtype(shape, dtype)
        for shape, dtype in zip(*self._inputs.example_shape_dtype))
    eval_model.init(example_signature)
    self._input_signature = example_signature
    eval_output_signature = eval_model.output_signature(example_signature)
    metric_weights, metric_state = metrics_branch.init(eval_output_signature)
    self._metrics_weights = self._for_n_devices(metric_weights)
    self._metrics_state = self._for_n_devices(metric_state)

    # Jit evaluation and update so they're fast.
    self._jit_eval = _jit_predict_fn(
        eval_model, metrics_branch, self._n_devices)
    self._jit_update_fn = _jit_update_fn(
        train_model, loss_fn, opt, self._n_devices)

    self._model_train = train_model
    self._model_predict_eval = eval_model
    self._loss_fn = loss_fn
    self._lr_schedule = lr_schedule

    # Placeholders; these fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._opt_state = None
    self._step = None
    self._model_state = None
    self.reset(output_dir)
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=D_MODEL,
                       d_ff=D_FF,
                       n_layers=N_LAYERS,
                       n_heads=N_HEADS,
                       max_len=MAX_SEQUENCE_LENGTH,
                       dropout=DROPOUT_RATE,
                       dropout_shared_axes=DROPOUT_SHARED_AXES,
                       mode=MODE,
                       ff_activation=FF_ACTIVATION_TYPE):
    """Builds a Transformer encoder suitable for N-way classification.

    The model maps tokenized text to ``n_classes`` activations:

    - input: token IDs of shape (batch_size, sequence_length), with
      sequence_length <= ``max_len``; IDs lie in ``range(vocab_size)`` and 0
      values mark padding positions.
    - output: raw (non-normalized) activations over ``n_classes`` categories,
      of shape (batch_size, ``n_classes``).

    Args:
      vocab_size: Input vocabulary size.
      n_classes: Last dimension of the output, for N-way classification.
      d_model: Last dimension of activations at most points in the model,
        including the embedding output.
      d_ff: Last dimension of the wider :py:class:`Dense` layer in the
        feedforward part of each encoder block.
      n_layers: Number of encoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Dropout rate for the embedding dropout; also forwarded to
        every encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask, e.g.
        ``(0, 1)`` to share along batch and sequence axes.
      mode: If ``'train'``, dropout is active; otherwise values pass through
        unaltered.
      ff_activation: Activation layer type used in each encoder block.

    Returns:
      A ``tl.Serial`` layer mapping token IDs to raw activations over the
      output classes.
    """
    block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation)

    layers = [
        tl.Branch([], tl.PaddingMask()),  # Creates masks from token copy.
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
        [_EncoderBlock(*block_args) for _ in range(n_layers)],
        tl.Select([0], n_in=2),           # Drops the masks.
        tl.LayerNorm(),
        tl.Mean(axis=1),
        tl.Dense(n_classes),
    ]
    return tl.Serial(*layers)
def test_default_name(self):
    """The string form of a `Branch` layer contains the word 'Branch'."""
    branch = tl.Branch(tl.Add(), DivideBy(0.5))
    rendered = str(branch)
    self.assertIn('Branch', rendered)
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
    """Builds a Transformer whose decoder attends over concatenated inputs.

    The model takes encoder-side tokens and decoder-side tokens. The encoder
    output is concatenated with the shifted, embedded decoder input, run
    through decoder blocks, and the decoder part is stripped back out before
    the final projection.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      dropout_shared_axes: axes on which to share dropout mask.
      max_len: int: maximum symbol length for positional encoding.
      mode: str: 'train' or 'eval'; with 'predict' the encoder is wrapped in
        `tl.Cache`.
      ff_activation: the non-linearity in feed-forward layer.
      axial_pos_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.

    Returns:
      A `tl.Serial` layer mapping the token pair to log-softmax activations
      over a vocab set (followed on the stack by the decoder tokens).
    """
    embeddings = ct.EmbeddingAndPositionalEncodings(
        input_vocab_size, d_model, mode, dropout, dropout_shared_axes,
        max_len,
        output_vocab_size=output_vocab_size,
        axial_pos_shape=axial_pos_shape,
        d_axial_pos_embs=d_axial_pos_embs)
    in_encoder, out_encoder, output_vocab_size = embeddings

    # Encoder and decoder blocks take the same positional arguments.
    block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            transformer._EncoderBlock(*block_args))  # pylint: disable=protected-access

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []
    for _ in range(n_decoder_layers):
        decoder_blocks.append(
            transformer._DecoderBlock(*block_args))  # pylint: disable=protected-access

    # Input: encoder_side_tokens, decoder_side_tokens
    layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),                 # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),         # tok_e mask_e tok_e tok_d tok_d
        encoder,                                 # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),            # tok_e vec_e tok_e tok_d tok_d
        tl.Fn('EncoderMask',                     # mask_e vec_e tok_e tok_d tok_d
              lambda x: x != 0, n_out=1),

        # Decode.
        tl.Select([3, 1, 0, 2]),                 # tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
        out_encoder,                             # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),                       # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),            # vec_ed tok_e tok_d

        # Decoder blocks with causal attention.
        decoder_blocks,                          # vec_ed tok_e tok_d
        tl.LayerNorm(),                          # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),                 # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),             # vec_d tok_d
        tl.LogSoftmax(),                         # vec_d tok_d
    ]
    return tl.Serial(*layers)
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Builds a reversible Transformer encoder-decoder model.

    The model takes encoder-side tokens and decoder-side tokens. At the
    moment it supports dot-product attention only; for the attention types in
    the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      max_len: int: maximum symbol length for positional encoding.
      ff_activation: the non-linearity in feed-forward layer.
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      mode: str: 'train' or 'eval'; with 'predict' the input embedding runs
        in 'eval' mode and the encoder is wrapped in `tl.Cache`.

    Returns:
      A `tl.Serial` layer mapping the token pair to log-softmax activations
      over a vocab set.
    """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts
    # bool-dtype masks on the stack. This causes jax to error, even though
    # the so-called "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if fastmath.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def _token_embedder(vocab_size, layer_mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long
        # sequences.
        position_layer = tl.PositionalEncoding(
            max_len=max_len, dropout=dropout, mode=layer_mode)
        embedding_layer = tl.Embedding(vocab_size, d_model)
        dropout_layer = tl.Dropout(
            rate=dropout, shared_axes=[-2], mode=layer_mode)
        return [embedding_layer, dropout_layer, position_layer]

    def _average_pair():
        # Averages the two streams coming out of a reversible stack.
        return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)

    predicting = mode == 'predict'

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if
    # output_vocab_size is None. This isn't supported here because (a) Trax
    # shares weights by sharing layer instances, but we need two separate
    # instances to have mode == 'eval' for the encoder but mode == 'predict'
    # for the decoder; and (b) tl.Cache does not work if its sublayers
    # participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a
    # time. The encoder only ever runs over full sequences, which is why it's
    # switched to 'eval' mode instead.
    in_encoder = _token_embedder(
        input_vocab_size, 'eval' if predicting else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = _token_embedder(output_vocab_size, mode)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, tl.SelfAttention, dropout,
                         ff_activation, ff_dropout, mode))

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        _average_pair(),
        tl.LayerNorm(),
    ])
    if predicting:
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = []
    for _ in range(n_decoder_layers):
        encoder_decoder_blocks.append(
            EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                ff_activation, ff_dropout, mode))

    # Input: encoder_side_tokens, decoder_side_tokens
    layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),                       # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
        ]),                                         # tok_e mask tok_d .....

        # Encode.
        encoder,                                    # vec_e mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),                       # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),                   # tok_d vec_e mask .....
        out_encoder,                                # vec_d vec_e mask .....
        tl.Dup(),                                   # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        _average_pair(),                            # vec_d vec_e mask .....
        tl.LayerNorm(),                             # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),                     # vec_d .....
        tl.Dense(output_vocab_size),                # vec_d .....
        tl.LogSoftmax(),                            # vec_d .....
    ]
    return tl.Serial(*layers)
def BERTPretrainingHead(n_classes):
    """Returns the pretraining head: classifier and MLM heads side by side.

    Both heads are combined with `tl.Branch`, so each one is applied to the
    same input and their outputs are returned together.

    Args:
      n_classes: Number of classes, forwarded to `BERTClassifierHead`.

    Returns:
      A single `tl.Branch` layer holding `BERTClassifierHead(n_classes)` and
      `BERTMLMHead()`.
    """
    # The original had a trailing comma after this expression, which made the
    # function return a 1-tuple `(layer,)` rather than the layer itself.
    return tl.Branch(BERTClassifierHead(n_classes), BERTMLMHead())
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Builds an encoder-decoder Transformer model.

    The model takes encoder-side tokens and decoder-side tokens and produces
    log-softmax activations over the output vocabulary, followed on the stack
    by a copy of the decoder tokens (kept for use in the loss).

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab, and the
        same embedding layers are reused on the decoder side.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      max_len: int: maximum symbol length for positional encoding.
      mode: str: 'train' or 'eval'.
      ff_activation: the non-linearity in feed-forward layer.

    Returns:
      A `tl.Serial` layer mapping the token pair to activations over a vocab
      set.
    """
    def _embedding_stack(vocab_size):  # tokens --> vectors
        # NOTE(review): other builders in this file call
        # `tl.Embedding(vocab_size, d_model)`; confirm which argument order
        # the Embedding layer used by this function expects.
        return [
            tl.Embedding(d_model, vocab_size),
            tl.Dropout(rate=dropout, mode=mode),
            tl.PositionalEncoding(max_len=max_len),
        ]

    in_encoder = _embedding_stack(input_vocab_size)
    if output_vocab_size is None:
        # Same vocabulary on both sides: reuse the very same layer instances.
        out_encoder = in_encoder
        output_vocab_size = input_vocab_size
    else:
        out_encoder = _embedding_stack(output_vocab_size)

    # The fifth positional argument of each block is the layer index.
    encoder_blocks = []
    for layer_idx in range(n_encoder_layers):
        encoder_blocks.append(
            _EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                          ff_activation))

    encoder_decoder_blocks = []
    for layer_idx in range(n_decoder_layers):
        encoder_decoder_blocks.append(
            _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx,
                                 mode, ff_activation))

    # Input: encoder_side_tokens, decoder_side_tokens
    layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),                    # tok_e tok_d tok_d

        # Encode.
        tl.Branch(in_encoder, tl.PaddingMask()),  # vec_e masks ..... .....
        encoder_blocks,                          # vec_e masks ..... .....
        tl.LayerNorm(),                          # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),                    # tok_d masks vec_e .....
        tl.ShiftRight(),                         # tok_d ..... ..... .....
        out_encoder,                             # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,                  # vec_d masks ..... .....
        tl.LayerNorm(),                          # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),                  # vec_d tok_d
        tl.Dense(output_vocab_size),             # vec_d .....
        tl.LogSoftmax(),                         # vec_d .....
    ]
    return tl.Serial(*layers)
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 ff_dropout=0.1,
                 ff_chunk_size=0,
                 ff_use_sru=0,
                 ff_sparsity=0,
                 ff_sparsity_type='1inN',
                 attention_chunk_size=0,
                 encoder_attention_type=tl.Attention,
                 n_encoder_attention_layers=1,
                 decoder_attention_type=tl.CausalAttention,
                 n_decoder_attention_layers=2,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
    """Builds a configurable Transformer with concatenated decoder input.

    The model takes encoder-side tokens and decoder-side tokens. The encoder
    output is concatenated with the shifted, embedded decoder input, run
    through decoder blocks, and the decoder part is stripped back out before
    the final projection.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      dropout_shared_axes: axes on which to share dropout mask.
      max_len: int: maximum symbol length for positional encoding.
      mode: str: 'train' or 'eval'; with 'predict' the encoder is wrapped in
        `tl.Cache`.
      ff_activation: the non-linearity in feed-forward layer.
      ff_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward.
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity.
      ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
        ff_sparsity_type=`'1inN'` and use BlockSparseFF if
        ff_sparsity_type=`'Block'`.
      attention_chunk_size: int, if > 0 run attention chunked at this size.
      encoder_attention_type: The attention layer to use for the encoder
        part.
      n_encoder_attention_layers: int, within each encoder block, how many
        attention layers to have.
      decoder_attention_type: The attention layer to use for the
        encoder-decoder attention.
      n_decoder_attention_layers: int, within each decoder block, how many
        attention layers to have.
      axial_pos_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.

    Returns:
      A `tl.Serial` layer mapping the token pair to log-softmax activations
      over a vocab set (followed on the stack by the decoder tokens).
    """
    embeddings = ct.EmbeddingAndPositionalEncodings(
        input_vocab_size,
        d_model,
        mode,
        dropout,
        dropout_shared_axes,
        max_len,
        output_vocab_size=output_vocab_size,
        axial_pos_shape=axial_pos_shape,
        d_axial_pos_embs=d_axial_pos_embs)
    in_encoder, out_encoder, output_vocab_size = embeddings

    # Leading positional arguments common to encoder and decoder blocks.
    shared_block_args = (
        d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
        ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
        ff_sparsity_type, attention_chunk_size)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            ct.EncoderBlock(*shared_block_args, encoder_attention_type,
                            n_encoder_attention_layers))

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []
    for _ in range(n_decoder_layers):
        decoder_blocks.append(
            ct.DecoderBlock(*shared_block_args, decoder_attention_type,
                            n_decoder_attention_layers))

    # Input: encoder_side_tokens, decoder_side_tokens
    layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),                 # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),         # tok_e mask_e tok_e tok_d tok_d
        encoder,                                 # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),            # tok_e vec_e tok_e tok_d tok_d
        tl.Fn('EncoderMask',                     # mask_e vec_e tok_e tok_d tok_d
              lambda x: x != 0, n_out=1),

        # Decode.
        tl.Select([3, 1, 0, 2]),                 # tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
        out_encoder,                             # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),                       # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),            # vec_ed tok_e tok_d

        # Decoder blocks with causal attention.
        decoder_blocks,                          # vec_ed tok_e tok_d
        tl.LayerNorm(),                          # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),                 # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),             # vec_d tok_d
        tl.LogSoftmax(),                         # vec_d tok_d
    ]
    return tl.Serial(*layers)
def test_add_div(self):
    """`Branch(Add, DivideBy(0.5))` yields both the sum and the quotient."""
    first = np.array([1, 2, 3])
    second = np.array([10, 20, 30])
    branch = tl.Branch(tl.Add(), DivideBy(0.5))

    outputs = branch([first, second])

    # Add consumes both inputs; DivideBy(0.5) acts on the first one.
    expected = [[11, 22, 33], [2, 4, 6]]
    self.assertEqual(as_list(outputs), expected)
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Builds an encoder-decoder Transformer model.

    The model takes encoder-side tokens and decoder-side tokens. The decoder
    tokens are duplicated up front so that a copy remains available for the
    loss.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab, and the
        same embedding layers are reused on the decoder side.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      max_len: int: maximum symbol length for positional encoding.
      mode: str: 'train' or 'eval'.
      ff_activation: the non-linearity in feed-forward layer.

    Returns:
      A `tl.Serial` layer mapping the token pair to log-softmax activations
      over a vocab set.
    """
    def _embed(vocab_size):  # tokens --> vecs
        return [
            tl.Embedding(d_model, vocab_size),
            tl.Dropout(rate=dropout, mode=mode),
            tl.PositionalEncoding(max_len=max_len),
        ]

    in_embed = _embed(input_vocab_size)
    if output_vocab_size is None:
        # Same vocabulary on both sides: reuse the very same layer instances.
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = _embed(output_vocab_size)

    # masks vectors --> masks vectors
    encoder_stack = []
    for layer_idx in range(n_encoder_layers):
        encoder_stack.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                         ff_activation))

    # vecs_d masks vecs_e --> vecs_d masks vecs_e
    encoder_decoder_stack = []
    for layer_idx in range(n_decoder_layers):
        encoder_decoder_stack.append(
            EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                           ff_activation))

    # Input: encoder_side_tokens, decoder_side_tokens
    layers = [                                    # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),                # toks_e toks_d toks_d (for loss)
        tl.Swap(),                                # toks_d toks_e ....
    ]

    # Encode; the first stack element (toks_d) passes through untouched.
    encode_branch = [
        tl.Branch(in_embed, tl.PaddingMask()),    # ______ vecs_e masks
        encoder_stack,                            # ______ vecs_e masks
        tl.LayerNorm(),                           # ______ vecs_e .....
        tl.Swap(),                                # ______ masks vecs_e
    ]
    layers.append(tl.Parallel([], encode_branch))  # toks_d masks vecs_e

    # Decode.
    layers.extend([
        tl.ShiftRight(),                          # toks_d ..... ......
        out_embed,                                # vecs_d ..... ......
        tl.Branch([], tl.EncoderDecoderMask()),   # vecs_d masks ......
        encoder_decoder_stack,                    # vecs_d masks vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),    # vecs_d
        tl.LayerNorm(),                           # vecs_d
        tl.Dense(output_vocab_size),              # vecs_d
        tl.LogSoftmax(),                          # vecs_d
    ])
    return tl.Serial(*layers)