Example #1
0
 def test_call_and_grad(self):
     """Favor attention stack: forward shapes and gradient of a loss model."""

     def build_model(*extra_layers):
         # Embedding + padding mask -> Favor attention -> keep only the vectors;
         # `extra_layers` are appended afterwards (e.g. a loss layer).
         return tl.Serial(
             tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
             sparsity.Favor(d_feature=4, n_heads=2),
             tl.Select([0], n_in=2),
             *extra_layers,
         )

     encoder_only = build_model()
     with_loss = build_model(tl.WeightedCategoryCrossEntropy())

     tokens = np.ones((1, 2), dtype=np.int32)
     loss_weights = np.ones_like(tokens).astype(np.float32)
     tokens_sig = shapes.signature(tokens)
     loss_weights_sig = shapes.signature(loss_weights)

     # Encoder-only forward pass keeps one vector of depth 4 per token.
     encoder_only.init(tokens_sig)
     self.assertEqual(encoder_only(tokens).shape, (1, 2, 4))

     # Loss model takes (inputs, targets, weights) and returns a scalar.
     with_loss.init((tokens_sig, tokens_sig, loss_weights_sig))
     self.assertEqual(with_loss((tokens, tokens, loss_weights)).shape, ())

     model_state = with_loss.state
     rng = fastmath.random.get_prng(0)

     def loss_fn(params, batch):
         return with_loss.pure_fn(batch, params, model_state, rng=rng)[0]

     grads = fastmath.grad(loss_fn)(with_loss.weights,
                                    (tokens, tokens, loss_weights))
     # NOTE(review): this index path is expected to reach the Embedding(3, 4)
     # table gradient; it depends on how trax nests Branch weights.
     self.assertEqual(grads[0][1][0].shape, (3, 4))
Example #2
0
    def test_funnel_block_forward_shape(self):
        """A pooling _FunnelBlock halves the (even) sequence length."""
        n_even = 4
        d_model = 8

        # `np.float` was only a deprecated alias of the builtin `float` and was
        # removed in NumPy 1.24; the builtin yields the same float64 dtype.
        x = np.ones((1, n_even, d_model), dtype=float)
        mask = np.ones((1, n_even), dtype=np.int32)

        masker = tl.PaddingMask()
        mask = masker(mask)

        block = tl.Serial(
            ft._FunnelBlock(d_model,
                            8,
                            2,
                            0.1,
                            None,
                            'train',
                            tl.Relu,
                            tl.AvgPool, (2, ), (2, ),
                            separate_cls=True))

        xs = [x, mask]
        _, _ = block.init(shapes.signature(xs))

        y, _ = block(xs)

        self.assertEqual(y.shape, (1, n_even // 2, d_model))
Example #3
0
def TransformerEncoder(vocab_size=vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu,
                       EncoderBlock=EncoderBlock):
    """
    Build a Transformer encoder that classifies a tensor of tokens.

    Tokens are embedded, passed through dropout and positional encoding, and
    then processed by `n_layers` encoder blocks together with a padding mask.
    The result is layer-normalized, mean-pooled over axis 1, and projected
    to log-probabilities over `n_classes` classes.

    Args:
        vocab_size (int): vocab size. Defaults to the module-level `vocab_size`.
        n_classes (int): number of output classes. Defaults to 10.
        d_model (int): depth of the embedding. Defaults to 512.
        d_ff (int): depth of the feed-forward layer. Defaults to 2048.
        n_layers (int): number of encoder blocks. Defaults to 6.
        n_heads (int): number of attention heads. Defaults to 8.
        dropout (float): dropout rate (how much to drop out). Defaults to 0.1.
        dropout_shared_axes: axes on which to share the dropout mask.
            Defaults to None.
        max_len (int): maximum symbol length for positional encoding.
            Defaults to 2048.
        mode (str): 'train' or 'eval'. Defaults to 'train'.
        ff_activation (function): non-linearity in the feed-forward layer.
            Defaults to tl.Relu.
        EncoderBlock (function): factory returning a single encoder block.
            Defaults to the module-level `EncoderBlock`.

    Returns:
        trax.layers.combinators.Serial: A Transformer model as a layer that maps
        from a tensor of tokens to activations over a set of output classes.
    """
    # Token embedding -> dropout -> positional encoding.
    embed_with_positions = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    # One independently-parameterized encoder block per layer.
    block_stack = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation)
        for _ in range(n_layers)
    ]

    return tl.Serial(
        tl.Branch(embed_with_positions, tl.PaddingMask()),  # vecs mask
        block_stack,                                          # vecs mask
        tl.Select([0], n_in=2),                               # vecs
        tl.LayerNorm(),
        tl.Mean(axis=1),                                      # pool positions
        tl.Dense(n_classes),
        tl.LogSoftmax(),
    )
Example #4
0
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Returns a Transformer encoder model.

    The input to the model is a tensor of tokens. Tokens are embedded,
    encoded by `n_layers` encoder blocks (with a padding mask), normalized,
    mean-pooled over axis 1, and mapped to log-probabilities over classes.

    Args:
        vocab_size: int: vocab size.
        n_classes: how many classes on output.
        d_model: int: depth of embedding.
        d_ff: int: depth of feed-forward layer.
        n_layers: int: number of encoder blocks.
        n_heads: int: number of attention heads.
        dropout: float: dropout rate (how much to drop out).
        dropout_shared_axes: axes on which to share dropout mask.
        max_len: int: maximum symbol length for positional encoding.
        mode: str: 'train' or 'eval'.
        ff_activation: the non-linearity in feed-forward layer.

    Returns:
        A Transformer model as a layer that maps from a tensor of tokens to
        activations over a set of output classes.
    """
    embedding_stage = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation)
        for _ in range(n_layers)
    ]

    encode = [
        tl.Branch(embedding_stage, tl.PaddingMask()),  # vecs masks
        blocks,                                        # vecs masks
        tl.Select([0], n_in=2),                        # vecs
        tl.LayerNorm(),                                # vecs
    ]
    classify = [
        tl.Mean(axis=1),      # vecs
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),      # vecs
    ]
    return tl.Serial(*encode, *classify)  # toks -> class log-probs
Example #5
0
 def test_padding_mask(self):
     """PaddingMask marks non-zero positions True, with shape (B, 1, 1, L)."""
     mask_layer = tl.PaddingMask()
     padded = np.array([
         [1., 2., 3., 4., 0.],
         [1., 2., 3., 0., 0.],
         [1., 2., 0., 0., 0.],
     ])
     mask = mask_layer(padded)

     self.assertEqual(padded.shape, (3, 5))
     self.assertEqual(mask.shape, (3, 1, 1, 5))

     expected = [
         [[[True, True, True, True, False]]],
         [[[True, True, True, False, False]]],
         [[[True, True, False, False, False]]],
     ]
     np.testing.assert_equal(mask, expected)
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Returns a Transformer encoder model.

    The input to the model is a tensor of tokens. The tokens are duplicated:
    one copy is embedded and the other becomes a padding mask. Both flow
    through `n_layers` encoder blocks, after which the mask is dropped. The
    vectors are then normalized, averaged over axis 1, and mapped to
    log-probabilities over `n_classes`.

    Args:
        vocab_size: int: vocab size.
        n_classes: how many classes on output.
        d_model: int: depth of embedding.
        d_ff: int: depth of feed-forward layer.
        n_layers: int: number of encoder blocks.
        n_heads: int: number of attention heads.
        dropout: float: dropout rate (how much to drop out).
        max_len: int: maximum symbol length for positional encoding.
        mode: str: 'train' or 'eval'.
        ff_activation: the non-linearity in feed-forward layer.

    Returns:
        A Transformer model as a layer that maps from a tensor of tokens to
        activations over a set of output classes.
    """
    # NOTE(review): Embedding gets (d_model, vocab_size) here, while the other
    # examples call Embedding(vocab_size, d_model). This presumably targets an
    # older trax API whose order was (d_feature, vocab_size); confirm against
    # the installed trax version before changing it.
    token_embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    # The fifth positional argument passes each block its layer index.
    encoder_stack = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                     ff_activation)
        for layer_idx in range(n_layers)
    ]
    return tl.Serial(                                     # toks
        tl.Dup(),                                         # toks toks
        tl.Parallel(token_embedder, tl.PaddingMask()),    # vecs mask
        encoder_stack,                                    # vecs mask
        tl.Parallel([], tl.Drop()),                       # vecs
        tl.LayerNorm(),                                   # vecs
        tl.Mean(axis=1),                                  # average over length
        tl.Dense(n_classes),                              # vecs
        tl.LogSoftmax(),                                  # vecs
    )
Example #7
0
 def _inp_layers():
     """Build the input stage that produces vec_e, mask_e, tok_e, tok_d.

     The shape contract is enforced by tl.AssertFunction. The letters mean
     b: batch, l/r: encoder/decoder length, f: input-feature depth,
     d: output-vector depth.
     """
     if input_vocab_size is None:
         # Input is vec_e, mask_e, tok_d. Wherever downstream ops expect
         # tok_e, mask_e is supplied instead, on the expectation that those ops
         # only look at padding vs. non-padding.
         return tl.AssertFunction(
             'blf,bl,br->bld,bl,bl,br',
             tl.Serial(                           # vec_e mask_e tok_d
                 tl.Select([0, 1, 1, 2]),
                 tl.Parallel(in_encoder, [], _AsTokenIDs()),
             ))                                   # vec_e mask_e tok_e tok_d

     # Token input: embed one copy of tok_e and derive the mask from another.
     return tl.AssertFunction(
         'bl,br->bld,bl,bl,br',
         tl.Serial(                               # tok_e tok_d
             tl.Select([0, 0, 0, 1]),
             tl.Parallel(in_encoder, [tl.PaddingMask(), _RemoveAxes12()]),
         ))                                       # vec_e mask_e tok_e tok_d
Example #8
0
def ConfigurableTransformerEncoder(vocab_size,
                                   n_classes=10,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   max_len=2048,
                                   dropout=0.1,
                                   dropout_shared_axes=None,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   ff_dropout=0.1,
                                   ff_chunk_size=0,
                                   ff_use_sru=0,
                                   ff_sparsity=0,
                                   ff_sparsity_type='1inN',
                                   attention_chunk_size=0,
                                   attention_type=tl.Attention,
                                   pos_type=None,
                                   pos_axial_shape=None,
                                   pos_d_axial_embs=None):
    """Returns a Transformer encoder merged with an N-way categorization head.

    This model performs text categorization:

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length). The
        tensor elements are integers in `range(vocab_size)`, and `0` values
        mark padding positions.

      - output: rank 2 tensor of unnormalized per-class scores (logits);
        shape is (batch_size, `n_classes`). The final layer is a plain
        `tl.Dense(n_classes)` with no `LogSoftmax`, so the loss (or the
        caller) must normalize if log-probabilities are needed.

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
      n_classes: Final dimension of the output tensors, representing N-way
        classification.
      d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
      n_layers: Number of encoder blocks. Each block includes attention,
        dropout, residual, feed-forward (`Dense`), and activation layers.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
        way to save memory and apply consistent masks to activation vectors at
        different sequence positions.
      mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
      ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
      ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
        in addition to the feed-forward block (second int specifies sru size).
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity.
      ff_sparsity_type: string, if ff_sparsity > 0,
        use SparseFF if ff_sparsity_type=`'1inN'` and
        use BlockSparseFF if ff_sparsity_type=`'Block'`.
      attention_chunk_size: int, if > 0 run attention chunked at this size.
      attention_type: The attention layer to use for the encoder part.
      pos_type: string, the type of positional embeddings to use.
      pos_axial_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model.

    Returns:
      A Transformer model that maps strings (conveyed via token IDs) to
      unnormalized activations (logits) over a range of output classes.
    """
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        PositionalEncoder(mode, dropout, max_len, pos_type, pos_axial_shape,
                          pos_d_axial_embs)
    ]

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, ff_dropout, ff_chunk_size,
                     ff_use_sru, ff_sparsity, ff_sparsity_type,
                     attention_chunk_size, attention_type)
        for _ in range(n_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(  # toks
        # Encode.
        tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
        encoder_blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Map to output categories (raw logits; no LogSoftmax here).
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # logits
    )
Example #9
0
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            loss_sparsity_type='mult',
                            loss_sparsity=0,
                            loss_d_lowrank=0,
                            loss_sparsity_prob=None,
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            pos_type=None,
                            pos_axial_shape=None,
                            pos_d_axial_embs=None,
                            enc_dec_attention_sparsity=0):
    """Returns a full Transformer model.

    This model is an encoder-decoder that performs tokenized string-to-string
    ("source"-to-"target") transduction:

      - inputs (2):

          - source: rank 2 tensor representing a batch of text strings via
            token IDs plus padding markers; shape is (batch_size,
            sequence_length). The tensor elements are integers in
            `range(input_vocab_size)`, and `0` values mark padding positions.

          - target: rank 2 tensor representing a batch of text strings via
            token IDs plus padding markers; shape is (batch_size,
            sequence_length). The tensor elements are integers in
            `range(output_vocab_size)`, and `0` values mark padding positions.

      - outputs (2):

          - rank 3 tensor of per-position activations over the output
            vocabulary; shape is (batch_size, sequence_length,
            `output_vocab_size`). The final layer is
            `tl.SparseDenseWithOptions` and no `LogSoftmax` is applied here,
            so these are presumably unnormalized logits.

          - the (unshifted) target tokens, carried through for use in the loss.

    An example use would be to translate (tokenized) sentences from English to
    German.

    Args:
      input_vocab_size: Input vocabulary size -- each element of the source
        tensor should be an integer in `range(input_vocab_size)`. These
        integers typically represent token IDs from a vocabulary-based
        tokenizer.
      output_vocab_size: If specified, gives the vocabulary size for the
        targets; if None, then input and target integers (token IDs) are
        assumed to come from the same vocabulary.
      d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each encoder
        and decoder block.
      n_encoder_layers: Number of encoder blocks.
      n_decoder_layers: Number of decoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder/decoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
        along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
        way to save memory and apply consistent masks to activation vectors at
        different sequence positions.
      mode: If `'predict'`, use fast inference (the encoder is wrapped in
        `tl.Cache`). If `'train'`, each encoder/decoder block will include
        dropout; else, it will pass all values through unaltered.
      ff_activation: Type of activation function at the end of each
        encoder/decoder block; must be an activation-type subclass of `Layer`.
      ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
        in addition to the feed-forward block (second int specifies sru size).
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity.
      ff_sparsity_type: string, if ff_sparsity > 0,
        use SparseFF if ff_sparsity_type=`'1inN'` and
        use BlockSparseFF if ff_sparsity_type=`'Block'`.
      loss_sparsity_type: str, type of sparsity to used in loss layer. See
        SparseDenseWithOptions for options. None if no sparsity should be used.
      loss_sparsity: int, the sparsity for loss layer (if used).
      loss_d_lowrank: int, the dimensions for intermediate layer (if used).
      loss_sparsity_prob: float, the probability for sparse version of loss to
        be used. If None, only sparse version is used.
      attention_chunk_size: int, if > 0 run attention chunked at this size.
      encoder_attention_type: The attention layer to use for the encoder part.
      encoder_decoder_attention_type: The attention layer to use for the
        encoder-decoder attention.
      pos_type: string, the type of positional embeddings to use.
      pos_axial_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model.
      enc_dec_attention_sparsity: int, if > 0 use this sparsity in attention.

    Returns:
      A Transformer model as a layer that maps from a source-target tokenized
      text pair to activations over a vocab set.
    """
    # The helper may fill in output_vocab_size when it was None.
    in_encoder, out_encoder, output_vocab_size = (
        EmbeddingAndPositionalEncodings(input_vocab_size,
                                        d_model,
                                        mode,
                                        dropout,
                                        dropout_shared_axes,
                                        max_len,
                                        output_vocab_size=output_vocab_size,
                                        pos_type=pos_type,
                                        pos_axial_shape=pos_axial_shape,
                                        pos_d_axial_embs=pos_d_axial_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation, ff_dropout, ff_chunk_size,
                     ff_use_sru, ff_sparsity, ff_sparsity_type,
                     attention_chunk_size, encoder_attention_type)
        for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        # Cache the encoder output across incremental decoding steps.
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    encoder_decoder_blocks = [
        EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                            dropout_shared_axes, mode, ff_activation,
                            ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
                            ff_sparsity_type, attention_chunk_size,
                            encoder_decoder_attention_type,
                            enc_dec_attention_sparsity)
        for _ in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks ..... .....
        encoder,  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),  # tok_d ..... ..... .....
        out_encoder,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.SparseDenseWithOptions(  # logits tok_d
            output_vocab_size,
            d_input=d_model,
            sparsity_type=loss_sparsity_type,
            sparsity=loss_sparsity,
            d_lowrank=loss_d_lowrank,
            prob_sparse=loss_sparsity_prob,
            mode=mode),
    )
Example #10
0
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              mode='train'):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair: source, target (encoder-side tokens,
    decoder-side tokens). Its outputs are the final-layer activations over
    the output vocabulary plus the target tokens, which are carried through
    for use in the loss.

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab.
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      d_attention_key: int: depth of key vector for each attention head. If
        None, defaults to `d_model // n_heads`.
      d_attention_value: int: depth of value vector for each attention head. If
        None, defaults to `d_model // n_heads`.
      n_encoder_layers: int: number of encoder layers.
      n_decoder_layers: int: number of decoder layers.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      max_len: int: maximum symbol length for positional encoding.
      encoder_attention_type: class: attention class to use, such as
        SelfAttention.
      encoder_decoder_attention_type: class, or a non-empty tuple/list of
        classes used cyclically across decoder layers (its length must divide
        `n_decoder_layers`), such as SelfAttention.
      axial_pos_shape: forwarded to `ct.EmbeddingAndPositionalEncodings`;
        described upstream as a tuple of ints giving the input shape for axial
        position encoding (axial encoding disabled if unset).
        NOTE(review): the default is the string 'fixed-base', not a tuple.
        Presumably the helper also accepts a positional-encoding type name
        here; confirm against its definition.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      ff_activation: the non-linearity in feed-forward layer.
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity.
      loss_sparsity_type: str, type of sparsity to used in loss layer. See
        SparseDenseWithOptions for options. None if no sparsity should be used.
      loss_sparsity: int, the sparsity for loss layer (if used).
      loss_d_lowrank: int, the dimensions for intermediate layer (if used).
      loss_sparsity_prob: float, the probability for sparse version of loss to
        be used. If None, only sparse version is used.
      attention_chunk_size: int, if > 0 run attention chunked at this size.
      n_layers_forget: how often to have a forgetting block between layers.
      n_decoder_attention_layers: how many attention layers in a decoder block.
      use_bfloat16: whether to use bfloat16 for weights (default: False).
      reversible_encoder: whether to be reversible through the encoder.
      mode: str: 'train', 'eval', or 'predict'. In 'predict' mode the encoder
        is wrapped in `tl.Cache`.

    Returns:
      A Reformer model as a layer that maps from a (source, target) pair to
      activations over a vocab set.

    Raises:
      ValueError: if a per-head depth must be derived and `n_heads` does not
        divide `d_model`, or if `encoder_decoder_attention_type` is an empty
        tuple/list.
    """
    # Default per-head key/value depths split d_model evenly across heads.
    if ((d_attention_key is None or d_attention_value is None) and
            d_model % n_heads != 0):
        raise ValueError(
            f'n_heads ({n_heads}) must divide d_model ({d_model})')
    if d_attention_key is None:
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        d_attention_value = d_model // n_heads

    # Vector embeddings.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            axial_pos_shape=axial_pos_shape,
            d_axial_pos_embs=d_axial_pos_embs,
            use_bfloat16=use_bfloat16))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     attention_chunk_size=attention_chunk_size,
                     use_bfloat16=use_bfloat16,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = [  # vec_e mask_e tok_e vec_d tok_d
        tl.ReversibleSelect([0, 0]),  # vec_e1 vec_e2 mask_e tok_e vec_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget)
    ]
    if not reversible_encoder:
        # Merge the two reversible streams back into one encoder output.
        encoder += [
            tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
            tl.Dense(d_model, use_bfloat16=use_bfloat16),
            tl.LayerNorm(),
        ]
    encoder = tl.Serial(encoder)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # Normalize to a sequence of attention types, cycled across decoder layers.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        if not encoder_decoder_attention_type:
            raise ValueError(
                'encoder_decoder_attention_type must not be an empty sequence')
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]

    decoder_blocks = []
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        decoder_block = DecoderBlock(
            d_model,
            d_ff,
            d_attention_key,
            d_attention_value,
            n_heads,
            attention_type=layer_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            n_attention_layers=n_decoder_attention_layers,
            use_bfloat16=use_bfloat16,
            mode=mode)
        decoder_blocks.append(decoder_block)

    dense_loss_layer = tl.SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

    # Layers to merge encoder and decoder, see below for details.
    if reversible_encoder:
        encdec_layers = [
            # vec_e1 vec_e2 mask_e tok_e vec_d tok_d
            tl.ReversibleSelect([0, 1, 4, 2, 3]),
            # vec_e1 vec_e2 vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding2(mode=mode),  # vec_ed vec_ed tok_e tok_d
        ]
    else:
        encdec_layers = [
            # vec_e mask_e tok_e vec_d tok_d
            tl.ReversibleSelect([0, 3, 1, 2]),  # vec_e vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d
            tl.ReversibleSelect([0, 0]),  # vec_ed vec_ed tok_e tok_d
        ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 0, 1, 1]),  # tok_e tok_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            in_encoder,
            [],
            [],
            [tl.ShiftRight(mode=mode), out_encoder]),  # vec_e tok_e tok_e vec_d tok_d
        tl.Parallel([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),  # vec_e mask_e tok_e vec_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Concat encoder and decoder, given encoder mask.
        encdec_layers,  # vec_ed vec_ed tok_e tok_d

        # Run decoder blocks.
        _ReversibleSerialForget(
            decoder_blocks, d_model,
            n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        dense_loss_layer,  # logits tok_d
    )
Example #11
0
def TransformerEncoder(vocab_size=vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu,
                       EncoderBlock=EncoderBlock):
    """Returns a Transformer encoder classifier.

    The model maps a tensor of token ids to log-probabilities over
    `n_classes` output classes.

    Args:
        vocab_size (int): vocab size. Defaults to the module-level `vocab_size`.
        n_classes (int): number of output classes. Defaults to 10.
        d_model (int): depth of embedding. Defaults to 512.
        d_ff (int): depth of feed-forward layer. Defaults to 2048.
        n_layers (int): number of encoder blocks to stack. Defaults to 6.
        n_heads (int): number of attention heads. Defaults to 8.
        dropout (float): dropout rate (how much to drop out). Defaults to 0.1.
        dropout_shared_axes: axes on which to share the dropout mask; passed
            as `shared_axes` to `tl.Dropout`. Defaults to None.
        max_len (int): maximum symbol length for positional encoding.
            Defaults to 2048.
        mode (str): 'train' or 'eval'. Defaults to 'train'.
        ff_activation: the non-linearity in the feed-forward layer.
            Defaults to tl.Relu.
        EncoderBlock (callable): factory for one encoder block, called as
            EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
            mode, ff_activation). Defaults to the module-level `EncoderBlock`.

    Returns:
        trax.layers.combinators.Serial: A Transformer model as a layer that maps
        from a tensor of tokens to log-probabilities over the output classes.
    """
    # Token ids -> embeddings, with dropout and positional information added.
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len)
    ]

    # One independently-parameterized block per layer.
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                     mode, ff_activation)
        for _ in range(n_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Produce (embeddings, padding mask) from the same token ids.
        tl.Branch(
            positional_encoder,
            tl.PaddingMask(),
        ),
        encoder_blocks,
        # Discard the mask and keep only the encoded activations.
        tl.Select([0], n_in=2),
        tl.LayerNorm(),
        # Pool over axis 1 (presumably the sequence axis) to get one vector per
        # example. NOTE(review): the mask was dropped above, so padding
        # positions are included in this mean — confirm that is intended.
        tl.Mean(axis=1),
        # Map to output categories.
        tl.Dense(n_classes),
        tl.LogSoftmax(),
    )
Пример #12
0
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              forget_dense=True,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair: source, target.

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab.
      d_model: int: depth of embedding; must be divisible by 2 * n_heads.
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head;
        defaults to d_model // n_heads when None.
      d_attention_value: int: depth of value vector for each attention head;
        defaults to d_model // n_heads when None.
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      encoder_attention_type: class: attention class to use, such as
        SelfAttention
      encoder_decoder_attention_type: class, or a tuple/list of classes used
        round-robin across decoder layers (its length must divide
        n_decoder_layers): attention class to use, such as SelfAttention
      pos_type: string, the type of positional embeddings to use.
      pos_axial_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      pos_d_axial_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match pos_axial_shape, and values must sum to
        d_model.
      pos_start_from_zero_prob: how often to start from 0 during training,
        (if 1.0, we always start from position 0, if less, we randomize).
      pos_max_offset_to_add: maximum offset to add to positions during training
        when randomizing; this offset plus input length must still be less than
        max_len for all training examples.
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
      loss_sparsity_type: str, type of sparsity to used in loss layer. See
        SparseDenseWithOptions for options. None if no sparsity should be used.
      loss_sparsity: int, the sparsity for loss layer (if used)
      loss_d_lowrank: int, the dimensions for intermediate layer (if used)
      loss_sparsity_prob: float, the probability for sparse version of loss to
        be used. If None, only sparse version is used.
      attention_chunk_size: int, if > 0 run attention chunked at this size
      n_layers_forget: how often to have a forgetting block between layers
      forget_dense: whether to use Dense or no-op (Serial) as a forget layer.
      n_decoder_attention_layers: how many attention layers in a decoder block
      use_bfloat16: whether to use bfloat16 for weights (default: False)
      reversible_encoder: whether to be reversible through the encoder
      use_two_swaps_per_encoder_block: whether to allow even number of swaps in
        the encoder
      center_layernorm: whether to use centering in LayerNorm (default) or if
        to skip it, which is known as RMS normalization.
      half_before_layer: int, halve d_model, d_ff and the attention key/value
        depths for the embeddings, the encoder and decoder layers before this
        layer index; the halves are concatenated back to full width there.
      double_after_layer: int, double d_model, d_ff and the attention key/value
        depths for decoder layers after this layer index.
      mode: str: 'train', 'eval' or 'predict'

    Returns:
      A Reformer model as a layer that maps from a source, target pair to
      activations over a vocab set.

    Raises:
      ValueError: if n_heads does not divide d_model / 2.
    """
    # Set default dimensions for attention head key and value sizes.
    if (d_model / 2) % n_heads != 0:
        raise ValueError(
            f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
    if d_attention_key is None:
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        d_attention_value = d_model // n_heads

    # Set values of d_model, d_ff and d_qkv for the first stage.
    # Floor division keeps these integers: true division would yield floats
    # (e.g. 256.0) that are unusable as layer sizes. d_model // 2 is exact
    # because the check above guarantees d_model is even.
    d_model1, d_ff1 = d_model, d_ff
    d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
    if half_before_layer:
        d_model1, d_ff1 = d_model // 2, d_ff // 2
        d_attention_key1 = d_attention_key // 2
        d_attention_value1 = d_attention_value // 2

    # Set values of d_model, d_ff and d_qkv for the final stage.
    d_model2, d_ff2 = d_model, d_ff
    d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
    if double_after_layer:
        d_model2, d_ff2 = d_model * 2, d_ff * 2
        d_attention_key2 = d_attention_key * 2
        d_attention_value2 = d_attention_value * 2

    # Vector embeddings.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model1,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            pos_type=pos_type,
            pos_axial_shape=pos_axial_shape,
            pos_d_axial_embs=pos_d_axial_embs,
            pos_start_from_zero_prob=pos_start_from_zero_prob,
            pos_max_offset_to_add=pos_max_offset_to_add,
            use_bfloat16=use_bfloat16))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model1,
                     d_ff1,
                     n_heads,
                     encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     attention_chunk_size=attention_chunk_size,
                     center_layernorm=center_layernorm,
                     use_bfloat16=use_bfloat16,
                     use_two_swaps_per_block=use_two_swaps_per_encoder_block,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = [  # vec_e mask_e tok_e vec_d tok_d
        tl.ReversibleSelect([0, 0]),  # vec_e1 vec_e2 mask_e tok_e vec_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model1, n_layers_forget,
                                forget_dense)
    ]
    if not reversible_encoder:
        # Merge the two reversible streams back into a single encoding.
        encoder += [
            tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
            tl.Dense(d_model1, use_bfloat16=use_bfloat16),
            tl.LayerNorm(),
        ]
    encoder = tl.Serial(encoder)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        # Grow d_model, d_ff, and d_qkv if requested.
        d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
        if half_before_layer and layer_idx >= half_before_layer:
            d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
        if double_after_layer and layer_idx > double_after_layer:
            d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
        decoder_block = DecoderBlock(
            d_m,
            d_f,
            d_k,
            d_v,
            n_heads,
            attention_type=layer_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            n_attention_layers=n_decoder_attention_layers,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            mode=mode)
        decoder_blocks.append(decoder_block)
        # Concatenating the two reversible halves doubles the model width at
        # the stage boundaries selected above.
        if half_before_layer and layer_idx == half_before_layer - 1:
            decoder_blocks.append(tl.ReversibleConcatenatePair())
        if double_after_layer and layer_idx == double_after_layer:
            decoder_blocks.append(tl.ReversibleConcatenatePair())

    dense_loss_layer = tl.SparseDenseWithOptions(
        output_vocab_size,
        d_input=d_model2,
        sparsity_type=loss_sparsity_type,
        sparsity=loss_sparsity,
        d_lowrank=loss_d_lowrank,
        prob_sparse=loss_sparsity_prob,
        use_bfloat16=use_bfloat16,
        mode=mode)

    # Layers to merge encoder and decoder, see below for details.
    if reversible_encoder:
        # Encoder output keeps two streams: vec_e1 vec_e2 mask_e tok_e vec_d tok_d
        encdec_layers = [
            tl.ReversibleSelect([0, 1, 4, 2,
                                 3]),  # vec_e1 vec_e2 vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding2(mode=mode),  # vec_ed vec_ed tok_e tok_d
        ]
    else:
        encdec_layers = [
            tl.ReversibleSelect([0, 3, 1,
                                 2]),  # vec_e vec_d mask_e tok_e tok_d
            t2.ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d
            tl.ReversibleSelect([0, 0]),  # vec_ed vec_ed tok_e tok_d
        ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 0, 1, 1]),  # tok_e tok_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            in_encoder,
            [],
            [],  # vec_e tok_e tok_e vec_d tok_d
            [tl.ShiftRight(mode=mode), out_encoder]),

        # Predict mode doesn't work with padding in encoder. Raising an exception
        # in jitted function isn't possible, so the second next best thing is
        # to convert every embedding to NaNs, so the user will not get subtly
        # wrong results, but clearly wrong results.
        (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),
        tl.Parallel([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                         # vec_e mask_e tok_e vec_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Concat encoder and decoder, given encoder mask.
        encdec_layers,

        # Run decoder blocks.
        _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                                forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        dense_loss_layer,  # vec_d tok_d
    )
Пример #13
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             axial_pos_shape=None,
             d_axial_pos_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair: source, target. The first element
    (source) feeds the encoder; the second (target) feeds the decoder and is
    also carried along for use in the loss.

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab.
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      mode: str: 'train', 'eval' or 'predict' ('predict' wraps the encoder in
        tl.Cache).
      axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

    Returns:
      A Reformer model as a layer that maps from a source, target pair to
      activations over a vocab set.
    """
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            axial_pos_shape=axial_pos_shape,
            d_axial_pos_embs=d_axial_pos_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     tl.SelfAttention,
                     dropout,
                     ff_activation,
                     ff_dropout,
                     mode=mode,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Embed, split into two reversible streams, then average them back.
    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    encoder_decoder_blocks = [
        EncoderDecoderBlock(d_model,
                            d_ff,
                            n_heads,
                            dropout,
                            ff_activation,
                            ff_dropout,
                            mode,
                            ff_use_sru=ff_use_sru,
                            ff_chunk_size=ff_chunk_size,
                            ff_sparsity=ff_sparsity)
        for _ in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                     # tok_e mask  tok_d .....

        # Encode.
        encoder,  # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
    )
Пример #14
0
def TransformerNoEncDecAttention(input_vocab_size,
                                 output_vocab_size=None,
                                 d_model=512,
                                 d_ff=2048,
                                 n_encoder_layers=6,
                                 n_decoder_layers=6,
                                 n_heads=8,
                                 dropout=0.1,
                                 dropout_shared_axes=None,
                                 max_len=2048,
                                 mode='train',
                                 ff_activation=tl.Relu):
  """Returns a Transformer model without encoder-decoder attention.

  Instead of cross-attention, the encoded source and the embedded (shifted)
  target are concatenated and run through causal decoder blocks together.

  This model expects an input pair: source, target. The first element
  (source) feeds the encoder; the second (target) feeds the decoder and is
  also carried along for use in the loss.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the source
      embedding layers are reused for the target.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict' ('predict' wraps the encoder in
      tl.Cache).
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  def PositionalEncoder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  # NOTE(review): when output_vocab_size is None the same layer instances are
  # used on both sides, so the cached encoder below shares weights with the
  # decoder embedding in 'predict' mode — confirm tl.Cache supports that.
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      transformer._EncoderBlock(d_model, d_ff, n_heads, dropout,  # pylint: disable=protected-access
                                dropout_shared_axes, mode, ff_activation)
      for _ in range(n_encoder_layers)]

  encoder = tl.Serial(
      in_encoder,
      encoder_blocks,
      tl.LayerNorm()
  )
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = [
      transformer._DecoderBlock(d_model, d_ff, n_heads, dropout,  # pylint: disable=protected-access
                                dropout_shared_axes, mode, ff_activation)
      for _ in range(n_decoder_layers)]

  # pylint: disable=protected-access
  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Simple encoder mask, doesn't contain extra dims.
      tl.Select([2, 0, 2], n_in=3),     # tok_e vec_e tok_e tok_d tok_d
      transformer._MaskOfRightShiftedArray(
          n_positions=0),               # mask_e vec_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 1, 0, 2]),          #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          transformer._MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      transformer._ConcatWithPadding(),  # vec_ed tok_e tok_d

      # Decoder blocks with causal attention
      decoder_blocks,                   # vec_ed tok_e tok_d
      tl.LayerNorm(),                   # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),          # vec_ed tok_e tok_d tok_d
      transformer._StripFromConcatenateWithPadding(),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),      # vec_d tok_d
      tl.LogSoftmax(),                  # vec_d tok_d
  )
Пример #15
0
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as
      SelfAttention
    encoder_decoder_attention_type: class, or a tuple/list of classes used
      round-robin across decoder layers (its length must divide
      n_decoder_layers): attention class to use, such as SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T. Note: it is only
      passed to the encoder blocks; decoder blocks do not receive it.
    mode: str: 'train', 'eval' or 'predict' (in 'predict' the encoder runs in
      'eval' mode and is wrapped in tl.Cache).

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # Note that this patches the jax module globally, not just for this model.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    # Cycle through the attention types, one per decoder layer.
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    # NOTE(review): ff_dropout is not passed here, so decoder blocks ignore
    # it — confirm whether DecoderBlock accepts it before forwarding.
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),                  # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),                 #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),                        # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                                    # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),         # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                     # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),          # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),                 # vec_d tok_d
      tl.LogSoftmax(),                             # vec_d tok_d
  )
Пример #16
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair: source, target. The source tokens are
    encoded; the target tokens are right-shifted and decoded, and a copy of
    the target tokens is passed through unchanged for use in the loss.

    At the moment, this model supports dot-product attention only. For the
    attention types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab. Source and
        target always get separate embedding layers.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
        runs in 'eval' mode over full sequences and is wrapped in `tl.Cache`,
        while the decoder runs one token at a time.

    Returns:
      A Reformer model as a layer that maps from a source, target pair to
      log-probability activations over the target vocab set; the copied
      target tokens remain as a second output for use in the loss.
    """
    def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long sequences.
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout,
                                                    mode=mode)
        return [
            tl.Embedding(vocab_size, d_model),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
            positional_encoding,
        ]

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    in_encoder = PositionalEncoder(input_vocab_size,
                                   mode='eval' if mode == 'predict' else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = PositionalEncoder(output_vocab_size, mode)

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     tl.SelfAttention,
                     dropout,
                     ff_activation,
                     ff_dropout,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        # Reduce the two reversible activation streams to one by averaging.
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                            ff_dropout, mode) for _ in range(n_decoder_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),  # tok_e mask tok_d .....

        # Encode.
        encoder,  # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
Пример #17
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target. The source tokens are
  encoded; the target tokens are right-shifted and decoded, and a copy of the
  target tokens is passed through unchanged for use in the loss.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Note: calling this function replaces the private module attribute
  `jax.api._check_inexact_input_vjp` with a no-op. The replacement is global
  and persists for the rest of the process.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab and share a single
      embedding/positional-encoding stack.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    log-probability activations over the target vocab set; the copied target
    tokens remain as a second output for use in the loss.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # NOTE: this assignment patches the jax module globally, not just for the
  # model built here.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  # With a shared vocab the very same layers (and hence weights) are reused
  # for the decoder side.
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_encoder_layers)]

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch(
          in_encoder, [tl.PaddingMask(),
                       tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]
          ),                                # vec_e  mask  tok_d .....
      tl.Dup(),                             # vec_e1 vec_e2 mask tok_d .....
      tl.ReversibleSerial(encoder_blocks),  # vec_e1 vec_e2 mask tok_d .....
      # The two sets of activations need to be reduced to one, in this case by
      # averaging them. Note that ReformerLM concatenates instead. Various
      # options (concat, average, add, keep only one, etc.) seem to perform
      # similarly. We don't concatenate here because we want exact parameter
      # parity with the standard Transformer.
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_e  mask tok_d .....
      tl.LayerNorm(),                       # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(),                      # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
Пример #18
0
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab (and then share
      the embedding layer).
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      runs over full sequences (its positional encoding uses 'eval' mode) and
      is wrapped in `tl.Cache`, while the decoder is fed one token at a time.
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  def Embedder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
    ]

  in_embedder = Embedder(input_vocab_size)
  out_embedder = (in_embedder if output_vocab_size is None
                  else Embedder(output_vocab_size))

  # Positional encodings are NOT shared between encoder and decoder. The
  # encoder always sees whole sequences, so it must not use 'predict' mode;
  # the decoder's positional encoding receives the real `mode` so that, in
  # 'predict' mode, it can handle one-token-at-a-time decoding.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = in_embedder + [
      tl.PositionalEncoding(max_len=max_len, mode=encoder_mode)]
  out_encoder = out_embedder + [
      tl.PositionalEncoding(max_len=max_len, mode=mode)]

  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      _EncoderBlock(
          d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
      for i in range(n_encoder_layers)]

  encoder = tl.Serial(
      in_encoder,
      encoder_blocks,
      tl.LayerNorm()
  )
  if mode == 'predict':
    # The source is fixed during decoding, so the encoder output is cached.
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      _EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
      for i in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),    # tok_e masks ..... .....
      encoder,                            # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),               # tok_d masks vec_e .....
      # Pass `mode` so the shift is expected to be skipped in 'predict' mode,
      # where each call carries a single (already previous) token.
      tl.ShiftRight(mode=mode),           # tok_d ..... ..... .....
      out_encoder,                        # vec_d ..... ..... .....
      tl.Branch(
          [], tl.EncoderDecoderMask()),   # vec_d masks ..... .....
      encoder_decoder_blocks,             # vec_d masks ..... .....
      tl.LayerNorm(),                     # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d tok_d
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
Пример #19
0
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                max_len=2048,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a full Transformer model.

    This model is an encoder-decoder that performs tokenized string-to-string
    ("source"-to-"target") transduction:

      - inputs (2):

          - source: rank 2 tensor representing a batch of text strings via token
            IDs plus padding markers; shape is (batch_size, sequence_length).
            The tensor elements are integers in `range(input_vocab_size)`, and
            `0` values mark padding positions.

          - target: rank 2 tensor representing a batch of text strings via token
            IDs plus padding markers; shape is (batch_size, sequence_length).
            The tensor elements are integers in `range(output_vocab_size)`, and
            `0` values mark padding positions.

      - outputs (2):

          - rank 3 tensor of unnormalized scores (logits) for each sequence
            position over possible target token IDs; shape is (batch_size,
            sequence_length, `output_vocab_size`). No softmax is applied
            inside the model; apply a log-softmax to obtain log-probabilities.

          - the target tokens, copied unchanged for use in the loss.

    An example use would be to translate (tokenized) sentences from English to
    German.

    Args:
      input_vocab_size: Input vocabulary size -- each element of the source
          tensor should be an integer in `range(input_vocab_size)`. These
          integers typically represent token IDs from a vocabulary-based
          tokenizer.
      output_vocab_size: If specified, gives the vocabulary size for the
          targets; if None, then input and target integers (token IDs) are
          assumed to come from the same vocabulary (and share the embedding).
      d_model: Final dimension of tensors at most points in the model, including
          the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each encoder
          and decoder block.
      n_encoder_layers: Number of encoder blocks.
      n_decoder_layers: Number of decoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
          when applying dropout within an encoder/decoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
          Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
          is a useful way to save memory and apply consistent masks to
          activation vectors at different sequence positions.
      mode: If `'predict'`, use fast inference (the encoder is cached and run
          in `'eval'` mode). If `'train'`, each encoder/decoder block will
          include dropout; else, it will pass all values through unaltered.
      ff_activation: Activation layer type passed to every `_EncoderBlock` and
          `_EncoderDecoderBlock`; must be an activation-type subclass of
          `Layer`.

    Returns:
      A Transformer model as a layer that maps from a source-target tokenized
      text pair to unnormalized scores (logits) over the target vocab set.
    """
    def Embedder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(vocab_size, d_model),
            tl.Dropout(rate=dropout,
                       shared_axes=dropout_shared_axes,
                       mode=mode),
        ]

    in_embedder = Embedder(input_vocab_size)
    out_embedder = (in_embedder if output_vocab_size is None else
                    Embedder(output_vocab_size))

    # Positional encodings are not shared between encoder and decoder.
    # Since encoder doesn't run stepwise, we do not use predict mode there.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = in_embedder + [
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode)
    ]
    out_encoder = out_embedder + [
        tl.PositionalEncoding(max_len=max_len, mode=mode)
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for _ in range(n_encoder_layers)
    ]

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        # The source is fixed during decoding, so the encoder output is cached.
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)
        for _ in range(n_decoder_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks ..... .....
        encoder,  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),  # tok_d ..... ..... .....
        out_encoder,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab (logits; no LogSoftmax here).
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
    )
Пример #20
0
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
    """Returns a Transformer encoder merged with an N-way categorization head.

    This model performs text categorization:

      - input: rank 2 tensor representing a batch of text strings via token IDs
        plus padding markers; shape is (batch_size, sequence_length). The tensor
        elements are integers in `range(vocab_size)`, and `0` values mark
        padding positions.

      - output: rank 2 tensor of unnormalized category scores (logits); shape
        is (batch_size, `n_classes`). No softmax is applied inside the model;
        apply a log-softmax to obtain log-probabilities.

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
          should be an integer in `range(vocab_size)`. These integers typically
          represent token IDs from a vocabulary-based tokenizer.
      n_classes: Final dimension of the output tensors, representing N-way
          classification.
      d_model: Final dimension of tensors at most points in the model, including
          the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each encoder
          block.
      n_layers: Number of encoder blocks. Each block includes attention,
          dropout, residual, feed-forward (`Dense`), and activation layers.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
          when applying dropout within an encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
          Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
          is a useful way to save memory and apply consistent masks to
          activation vectors at different sequence positions.
      mode: If `'train'`, each encoder block will include dropout; else, it
          will pass all values through unaltered.
      ff_activation: Activation layer type passed to every `_EncoderBlock`;
          must be an activation-type subclass of `Layer`.

    Returns:
      A Transformer model that maps strings (conveyed via token IDs) to
      unnormalized scores (logits) over a range of output classes.
    """
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len)
    ]

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for _ in range(n_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(  # toks
        # Encode.
        tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
        encoder_blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Map to output categories.
        # NOTE: the mask was dropped by the Select above, so this mean runs
        # over every position, padding positions included.
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # vecs
    )
Пример #21
0
def FunnelTransformerEncoder(vocab_size,
                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             encoder_segment_lengths=(2, 2, 2),
                             n_heads=8,
                             max_len=2048,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             mode='train',
                             ff_activation=tl.Relu,
                             pool_layer=tl.AvgPool,
                             pool_size=(2,),
                             strides=(2,),
                             separate_cls=True):
  """Returns a Funnel Encoder.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 2 tensor of unnormalized category scores (logits); shape is
      (batch_size, `n_classes`). No softmax is applied inside the model; apply
      a log-softmax to obtain log-probabilities.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
        classification.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Non-empty tuple, where each element denotes the
        number of transformer encoder blocks preceding a funnel transformer
        block. There is no funnel block after the last sequence of encoder
        blocks, therefore the total number of blocks in the model is equal to
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
        neighboring windows along each axis. If specified, must be a tuple of
        the same length as `pool_size`. If None, then offsets of 1 along each
        window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper) and only final
        embedding of this token is used for categorization - the rest are
        discarded. If `False`, each token from the beginning is pooled and
        all embeddings are averaged and mapped to output categories like in
        original `TransformerEncoder` model.

  Returns:
    A Funnel Transformer model that maps strings (conveyed via token IDs) to
    unnormalized scores (logits) over a range of output classes.

  Raises:
    AssertionError: if `encoder_segment_lengths` is empty. This is an
      `assert`, so the check is skipped when Python runs with `-O`.
  """
  assert encoder_segment_lengths, (
      'encoder_segment_lengths must contain at least one segment length')

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  encoder_blocks = []
  n_encoder_segments = len(encoder_segment_lengths)

  for segment_idx, segment_length in enumerate(encoder_segment_lengths):
    # Each segment starts with `segment_length` regular encoder blocks ...
    for _ in range(segment_length):
      encoder_blocks.append(
          _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation))

    # ... and every segment except the last is followed by a funnel block.
    if segment_idx != n_encoder_segments - 1:
      encoder_blocks.append(
          _FunnelBlock(d_model, d_ff, n_heads, dropout,
                       dropout_shared_axes, mode,
                       ff_activation, pool_layer, pool_size,
                       strides, separate_cls))

  # With separate_cls, SelectFirst presumably keeps only the first (cls)
  # position. Otherwise all positions are averaged; the mask is dropped by
  # the Select below, so padding positions are included in that mean.
  cls_pooling = SelectFirst() if separate_cls else tl.Mean(axis=1)

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Map to output categories.
      cls_pooling,                                # cls
      tl.Dense(n_classes),                        # cls
  )
Пример #22
0
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Returns a Transformer-style encoder with an N-way categorization head.

    For each item in a batch, this model maps a whole sequence of token ids to
    a single vector of class log-probabilities (sequence-to-vector, not
    sequence-to-sequence):

      - input: sequence of integers, usually token id's from a fixed-size
        vocabulary -- integers in `range(M)`, where `M` is the vocabulary
        size.

      - output: one N-dimensional vector per batch item (N = `n_classes`),
        a log-probability distribution over N discrete categories. Encoder
        activations are averaged over axis 1 (the sequence axis, assuming a
        (batch_size, sequence_length) token input) before classification, so
        the output has no sequence dimension.

    Args:
      vocab_size: "Vocabulary size" -- input integer id's must be in
          `range(vocab_size)`. Id's typically come from preprocessing text data
          with a vocabulary-based tokenizer.
      n_classes: Size/depth of the output vectors, intended for an N-way
          classification task.
      d_model: The basic embedding size (vector depth) of the model. This is
          the vector size used by the initial embedding layer and at many
          intermediate points in the model.
      d_ff: Vector depth (typically greater than `d_model`) used in the
          feed-forward (`Dense`) layer of each encoder block.
      n_layers: Number of encoder blocks. Each encoder block includes
          attention, dropout, residual, feed-forward (`Dense`), and activation
          layers.
      n_heads: Number of attention heads.
      dropout: Stochastic rate (probability) for dropping an activation value
          when applying dropout within an encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
      max_len: Maximum symbol length for positional encoding.
      mode: If `'train'`, each encoder block will include dropout; else, it
          will pass all values through unaltered.
      ff_activation: The activation function (layer) passed to every
          `_EncoderBlock`.

    Returns:
      A Transformer model as a layer that maps from token id's to
      log-probabilities over a set of output classes.
    """
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len)
    ]

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for _ in range(n_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(  # toks
        # Encode.
        tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
        encoder_blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Map to output categories.
        # NOTE: the mask was dropped by the Select above, so this mean runs
        # over every position, padding positions included.
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),  # vecs
    )
Пример #23
0
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
  """Returns a Full Funnel Transformer, that can be used for example for BERT.

  The encoder is split into segments of `_EncoderBlock`s; a `_FunnelBlock`
  (pooling/downsampling) sits between each pair of consecutive segments. An
  `_Upsampler` then combines the downsampled vectors with the residual saved
  after the first segment, and decoder blocks run on the result. The model
  outputs token-level activations over the whole vocabulary:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor of raw, unnormalized activations (logits) over
      `vocab_size` categories for each token; shape is
      (batch_size, sequence_length, vocab_size). The model ends in a `Dense`
      layer without a softmax, so apply `tl.LogSoftmax` (or use a loss that
      accepts logits) if log-probabilities are needed.


  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Non-empty tuple, where each element denotes the
        number of transformer encoder blocks preceding a funnel transformer
        block. There is no funnel block after the last sequence of encoder
        blocks, therefore the total number of blocks in the encoder is equal
        to `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`. It is also used as the pooling
        stride, and `pool_size[0]` determines the upsampling factor.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
        the embedding of the first token (`cls` from BERT paper). If `False`,
        every token is pooled. The same flag is passed to the upsampler.

  Returns:
    A Funnel Transformer layer mapping a batch of token IDs to per-token
    logits over `vocab_size` categories.

  Raises:
    ValueError: If `encoder_segment_lengths` is empty.
  """
  # Explicit check instead of `assert`: asserts are stripped under `python -O`,
  # which would turn this into an obscure IndexError further down.
  if not encoder_segment_lengths:
    raise ValueError('encoder_segment_lengths must be non-empty, got '
                     f'{encoder_segment_lengths!r}')

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  n_encoder_segments = len(encoder_segment_lengths)

  encoder_blocks_before_first_pooling = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout,
                    dropout_shared_axes, mode, ff_activation)
      for _ in range(encoder_segment_lengths[0])]
  encoder_blocks_from_first_pooling = []

  for i in range(1, n_encoder_segments):
    # Building i'th segment

    # Add funnel block between segments
    encoder_blocks_from_first_pooling.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout,
                     dropout_shared_axes, mode,
                     ff_activation, pool_layer,
                     pool_size=pool_size, strides=pool_size,
                     separate_cls=separate_cls))

    for _ in range(encoder_segment_lengths[i]):
      # Create segment_size encoder blocks
      encoder_blocks_from_first_pooling.append(
          _EncoderBlock(d_model, d_ff, n_heads, dropout,
                        dropout_shared_axes, mode, ff_activation))

  decoder_blocks = [_EncoderBlock(d_model, d_ff, n_heads, dropout,
                                  dropout_shared_axes, mode, ff_activation)
                    for _ in range(n_decoder_blocks)]

  # One funnel block per segment after the first, each striding by
  # pool_size[0]; the upsampler must undo the product of those strides.
  total_pool_size = pool_size[0] ** (n_encoder_segments - 1)

  # Assemble and return the model.
  return tl.Serial(                               # toks
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks_before_first_pooling,        # vecs masks
      tl.Select([0, 1, 0, 1]),
      # vecs masks residual = vecs old_masks
      encoder_blocks_from_first_pooling,          # vecs masks residual masks
      tl.Select([0, 2, 3]),                       # vecs residual masks
      tl.Parallel(
          # residual from first segment is taken before
          # normalization, so apply it now
          None, tl.LayerNorm(), None),            # vecs norm(residual) masks
      _Upsampler(total_pool_size, separate_cls),  # vecs masks
      decoder_blocks,
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),
      tl.Dense(vocab_size),                       # logits (no softmax)
  )
Пример #24
0
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Returns a Transformer model.

    This model expects an input pair: source, target (encoder-side tokens
    first, decoder-side tokens second). The target tokens are duplicated so a
    copy can be passed through for use in the loss.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab, and the encoder
        and decoder share the same embedding layers (and thus weights).
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'

    Returns:
      A Transformer model as a layer that maps from a source, target pair to
      log-probabilities over the output vocab for each target position,
      followed by the (unshifted) target tokens.
    """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        # Reusing the same layer instances shares weights between both sides.
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_encoder_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_decoder_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Serial(  # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
        tl.Swap(),  # toks_d toks_e ....

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Dup(),  # ______ toks_e toks_e
                tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  #        toks_d masks vecs_e
        tl.ShiftRight(),  #        toks_d ..... ......
        out_embed,  #        vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.EncoderDecoderMask()),  # ______    masks     ......
        encoder_decoder_stack,  # vecs_d    masks     vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
Пример #25
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target (encoder-side tokens
  first, decoder-side tokens second). A copy of the target tokens is passed
  through for use in the loss.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab size. Unlike the
      regular Transformer, embeddings are still NOT shared in that case (see
      the TODO in the body).
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the decoder
      runs one token at a time, while the encoder is wrapped in `tl.Cache`
      and its positional encoder runs in 'eval' mode.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    log-probabilities over the output vocab for each target position,
    followed by the (unshifted) target tokens.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], [                       # tok_e mask  tok_d .....
          tl.PaddingMask(),
          tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]),

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
Пример #26
0
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              n_layers_forget=0,
              mode='train'):
    """Reversible transformer encoder-decoder model.

    This model expects an input pair: source, target (encoder-side tokens
    first, decoder-side tokens second).

    The source is embedded and run through a reversible encoder. The encoder
    output is concatenated with the right-shifted target embeddings, the
    combined sequence is processed by the reversible decoder blocks, and the
    encoder part is stripped off again before mapping to the output vocab.
    Attention classes are chosen with `encoder_attention_type` and
    `encoder_decoder_attention_type`; for the attention types in the Reformer
    paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab and share the
        embedding layers.
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head;
        defaults to d_model // n_heads.
      d_attention_value: int: depth of value vector for each attention head;
        defaults to d_model // n_heads.
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      encoder_attention_type: class: attention class to use in the encoder,
        such as SelfAttention
      encoder_decoder_attention_type: class, or list/tuple of classes, used in
        the decoder blocks. A list/tuple is cycled through layer by layer and
        its length must divide n_decoder_layers.
      axial_pos_shape: forwarded to `PositionalEncoding` together with
        d_axial_pos_embs (default 'fixed-base'). A tuple of ints selects axial
        position encoding with that input shape; see `PositionalEncoding` for
        the other accepted values.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
      n_layers_forget: how often to have a forgetting block between layers
      mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the decoder
        runs one token at a time, while the encoder is wrapped in `tl.Cache`
        and its positional encoding runs in 'eval' mode.

    Returns:
      A Reformer model as a layer that maps from a source, target pair to
      log-probabilities over the output vocab for each target position,
      followed by the target tokens (kept for use in the loss).

    Raises:
      ValueError: if an attention depth is left as None and n_heads does not
        divide d_model, or if encoder_decoder_attention_type is a list/tuple
        that is empty or whose length does not divide n_decoder_layers.
    """
    # Default the per-head key/value depths to an even split of d_model.
    if d_attention_key is None or d_attention_value is None:
        if d_model % n_heads != 0:
            raise ValueError(
                f'n_heads ({n_heads}) must divide d_model ({d_model})')
        d_head = d_model // n_heads
        if d_attention_key is None:
            d_attention_key = d_head
        if d_attention_value is None:
            d_attention_value = d_head

    # Vector embeddings.
    def Embedder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(vocab_size, d_model),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        ]

    in_embedder = Embedder(input_vocab_size)
    # Reusing the same layer instances shares embedding weights.
    out_embedder = (in_embedder if output_vocab_size is None else
                    Embedder(output_vocab_size))

    def PositionalEnc(mode):
        return PositionalEncoding(mode, dropout, max_len, axial_pos_shape,
                                  d_axial_pos_embs)

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = in_embedder + [PositionalEnc(encoder_mode)]
    out_encoder = out_embedder + [PositionalEnc(mode)]
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     encoder_attention_type,
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     mode=mode) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.Dense(d_model),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    # Normalize to a list of attention classes, cycled through per layer.
    # Explicit ValueError (not assert) so the check survives `python -O`.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        n_types = len(encoder_decoder_attention_type)
        if n_types == 0 or n_decoder_layers % n_types != 0:
            raise ValueError(
                f'n_decoder_layers ({n_decoder_layers}) must be a multiple of '
                f'the number of encoder_decoder_attention_type entries '
                f'({n_types})')
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        decoder_block = DecoderBlock(d_model,
                                     d_ff,
                                     d_attention_key,
                                     d_attention_value,
                                     n_heads,
                                     attention_type=layer_attention_type,
                                     dropout=dropout,
                                     ff_activation=ff_activation,
                                     ff_dropout=ff_dropout,
                                     ff_use_sru=ff_use_sru,
                                     ff_chunk_size=ff_chunk_size,
                                     ff_sparsity=ff_sparsity,
                                     mode=mode)
        decoder_blocks.append(decoder_block)

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 0, 1, 1]),  # tok_e tok_e tok_e tok_d tok_d

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel(
            in_encoder,
            [],
            [],  # vec_e tok_e tok_e vec_d tok_d
            [tl.ShiftRight(mode=mode), out_encoder]),
        tl.Parallel([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                         # vec_e mask_e tok_e vec_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),  #  vec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given encoder mask.
        tl.Select([1, 0]),  # vec_e vec_d mask_e tok_e tok_d
        t2.ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),  # vec_ed1 vec_ed2 tok_e tok_d
        _ReversibleSerialForget(
            decoder_blocks, d_model,
            n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    )
Пример #27
0
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 ff_dropout=0.1,
                 ff_chunk_size=0,
                 ff_use_sru=0,
                 ff_sparsity=0,
                 ff_sparsity_type='1inN',
                 attention_chunk_size=0,
                 encoder_attention_type=tl.Attention,
                 n_encoder_attention_layers=1,
                 decoder_attention_type=tl.CausalAttention,
                 n_decoder_attention_layers=2,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
    """Returns a Transformer model.

    This model expects an input pair: source, target (encoder-side tokens
    first, decoder-side tokens second). The encoder output is concatenated
    with the right-shifted target embeddings and run through causal decoder
    blocks; the encoder part is then stripped off again.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None, the
        source and target are assumed to have the same vocab.
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train', 'eval' or 'predict'; in 'predict' mode the encoder
        is wrapped in `tl.Cache`.
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
      ff_sparsity_type: string, if ff_sparsity >0,
        use SparseFF if ff_sparsity_type=`'1inN'` and
        use BlockSparseFF if ff_sparsity_type=`'Block'`
      attention_chunk_size: int, if > 0 run attention chunked at this size
      encoder_attention_type: The attention layer to use for the encoder part.
      n_encoder_attention_layers: int, within each encoder block, how many
        attention layers to have.
      decoder_attention_type: The attention layer to use for the
        encoder-decoder attention.
      n_decoder_attention_layers: int, within each decoder block, how many
        attention layers to have.
      axial_pos_shape: tuple of ints: input shape to use for the axial position
        encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.

    Returns:
      A Transformer model as a layer that maps from a source, target pair to
      raw activations (final `Dense`, no softmax) over the output vocab for
      each target position, followed by the target tokens.
    """
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(input_vocab_size,
                                           d_model,
                                           mode,
                                           dropout,
                                           dropout_shared_axes,
                                           max_len,
                                           output_vocab_size=output_vocab_size,
                                           axial_pos_shape=axial_pos_shape,
                                           d_axial_pos_embs=d_axial_pos_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        ct.EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, encoder_attention_type,
                        n_encoder_attention_layers)
        for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        ct.DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, decoder_attention_type,
                        n_decoder_attention_layers)
        for _ in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),  #  tok_e vec_e tok_e tok_d tok_d
        tl.Fn(
            'EncoderMask',  # mask_e vec_e tok_e tok_d tok_d
            lambda x: x != 0,
            n_out=1),

        # Decode.
        tl.Select([3, 1, 0, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),  # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Decoder blocks with causal attention
        decoder_blocks,  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
    )
Пример #28
0
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=D_MODEL,
                d_ff=D_FF,
                n_encoder_layers=N_LAYERS,
                n_decoder_layers=N_LAYERS,
                n_heads=N_HEADS,
                max_len=MAX_SEQUENCE_LENGTH,
                dropout=DROPOUT_RATE,
                dropout_shared_axes=DROPOUT_SHARED_AXES,
                mode=MODE,
                ff_activation=FF_ACTIVATION_TYPE):
    """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: Array representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length),
          where sequence_length <= ``max_len``. Array elements are integers in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: Array representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length),
          where sequence_length <= ``max_len``. Array elements are integers in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - outputs (2):

        - 3-D array of raw activations with last/innermost dimension of
          ``output_vocab_size``, suitable for decoding into a batch of token
          strings; shape is (batch_size, sequence_length,
          ``output_vocab_size``).

        - the (unshifted) target tokens, passed through for use in the loss.

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in ``range(input_vocab_size)``. These integers
        typically represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if ``None``, then input and target integers (token IDs) are assumed to
        come from the same vocabulary, and the embedding layer is shared.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder/decoder blocks. The same rate is
        also used for attention dropout in encoder/decoder blocks.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``)
        is a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If ``'predict'``, use fast inference. If ``'train'``, each
        encoder/decoder block will include dropout; else, it will pass all
        values through unaltered.
    ff_activation: Type of activation function at the end of each
        encoder/decoder block; must be an activation-type subclass of
        :py:class:`Layer`.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
    # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise.
    encoder_mode = 'eval' if mode == 'predict' else mode

    # Share embedding weights if no separate output vocab size.
    in_embedder = tl.Embedding(input_vocab_size, d_model)
    if output_vocab_size is None:
        out_embedder = in_embedder
        output_vocab_size = input_vocab_size
    else:
        out_embedder = tl.Embedding(output_vocab_size, d_model)

    def _Dropout():
        return tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    def _EncBlock():
        return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)

    def _Encoder():
        encoder = tl.Serial(
            in_embedder,
            _Dropout(),
            tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
            [_EncBlock() for _ in range(n_encoder_layers)],
            tl.LayerNorm(),
        )
        return tl.Cache(encoder) if mode == 'predict' else encoder

    def _EncDecBlock():
        return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                    dropout_shared_axes, mode, ff_activation)

    # Input to model is encoder-side tokens and decoder-side tokens: tok_e, tok_d
    # Model output is decoder-side vectors and decoder-side tokens: vec_d  tok_d
    return tl.Serial(
        tl.Select([0, 1, 1]),  # Copies decoder tokens for use in loss.

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks tok_d tok_d
        _Encoder(),

        # Decode.
        tl.Select([2, 1, 0]),  # Re-orders inputs: tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),
        out_embedder,
        _Dropout(),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        [_EncDecBlock() for _ in range(n_decoder_layers)],
        tl.LayerNorm(),
        tl.Select([0], n_in=3),  # Drops masks and encoding vectors.

        # Map vectors to match output vocab size.
        tl.Dense(output_vocab_size),
    )
Пример #29
0
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """BERT (default hparams are for bert-base-uncased).

  The model consumes three inputs: token ids, token-type ids, and a third
  input that is dropped immediately. It produces two outputs:
    - the pooled vector: Dense(d_model) + Tanh applied to ``x[:, 0, :]`` of
      the final encoder activations, and
    - the full sequence of final encoder activations.

  Args:
    d_model: Width of the embeddings and of the activations in every encoder
        layer.
    vocab_size: Number of rows in the token (word-piece) embedding table.
    max_len: Maximum sequence length for the positional encoding.
    type_vocab_size: Number of rows in the token-type (segment) embedding
        table.
    n_heads: Number of attention heads; each head uses
        ``d_model // n_heads`` for its query/key and value widths.
    d_ff: Width of the hidden Dense layer in each feed-forward sublayer.
    n_layers: Number of encoder layers.
    head: Optional zero-argument callable; if given, its returned layer is
        appended after the BERT body.
    init_checkpoint: Checkpoint passed to ``PretrainedBERT``. It is only
        honored when ``mode == 'train'``; otherwise None is passed.
    mode: Layer mode forwarded to the positional encoding and to every
        self-attention layer.

  Returns:
    A ``PretrainedBERT`` layer, or ``tl.Serial(bert, head())`` when ``head``
    is not None.
  """
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  # tl.Embedding takes (vocab_size, d_feature): the table size comes first,
  # the embedding width second.
  word_embeddings = tl.Embedding(vocab_size, d_model)
  type_embeddings = tl.Embedding(type_vocab_size, d_model)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          # Padding mask built from the token ids, then squeezed on axes
          # (1, 2) -- presumably (batch, 1, 1, seq) -> (batch, seq); confirm.
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),
      tl.Add(),  # word + type embeddings; the mask stays below on the stack.
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model)
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

  encoder += [tl.Select([0], n_in=2)]  # Drop the mask

  # Split the activations into (first position, full sequence). Only the
  # first-position slice goes through the Dense + Tanh pooler.
  pooler = [
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
Пример #30
0
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=D_MODEL,
                       d_ff=D_FF,
                       n_layers=N_LAYERS,
                       n_heads=N_HEADS,
                       max_len=MAX_SEQUENCE_LENGTH,
                       dropout=DROPOUT_RATE,
                       dropout_shared_axes=DROPOUT_SHARED_AXES,
                       mode=MODE,
                       ff_activation=FF_ACTIVATION_TYPE):
    """Returns a Transformer encoder suitable for N-way classification.

    This model maps tokenized text to N-way (``n_classes``) activations:

      - input: Array representing a batch of text strings via token IDs plus
        padding markers; shape is (batch_size, sequence_length), where
        sequence_length <= ``max_len``. Array elements are integers in
        ``range(vocab_size)``, and 0 values mark padding positions.

      - output: Array representing a batch of raw (non-normalized) activations
        over ``n_classes`` categories; shape is (batch_size, ``n_classes``).

    Args:
      vocab_size: Input vocabulary size -- each element of the input array
          should be an integer in ``range(vocab_size)``. These integers
          typically represent token IDs from a vocabulary-based tokenizer.
      n_classes: Last/innermost dimension of output arrays, suitable for N-way
          classification.
      d_model: Last/innermost dimension of activation arrays at most points in
          the model, including the initial embedding output.
      d_ff: Last/innermost dimension of special (typically wider)
          :py:class:`Dense` layer in the feedforward part of each encoder
          block.
      n_layers: Number of encoder blocks. Each block includes attention,
          dropout, residual, layer-norm, feedforward (:py:class:`Dense`), and
          activation layers.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Stochastic rate (probability) for dropping an activation value
          when applying dropout within encoder blocks. The same rate is also
          used for attention dropout in encoder blocks.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
          Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``)
          is a useful way to save memory and apply consistent masks to
          activation vectors at different sequence positions.
      mode: If ``'train'``, each encoder block will include dropout; else, it
          will pass all values through unaltered.
      ff_activation: Type of activation function at the end of each encoder
          block; must be an activation-type subclass of :py:class:`Layer`.

    Returns:
      A Transformer model that maps strings (conveyed by token IDs) to
      raw (non-normalized) activations over a range of output classes.
    """
    def _Dropout():
        # Must return the layer: a None entry inside tl.Serial is not a Layer
        # and would be rejected when the model is assembled.
        return tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes,
                          mode=mode)

    def _EncBlock():
        return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)

    return tl.Serial(
        tl.Branch([],
                  tl.PaddingMask()),  # Creates masks from copy of the tokens.
        tl.Embedding(vocab_size, d_model),
        _Dropout(),
        # NOTE(review): `mode` is not forwarded here, so PositionalEncoding
        # uses its own default; presumably intentional because the encoder
        # always sees whole sequences -- confirm.
        tl.PositionalEncoding(max_len=max_len),
        [_EncBlock() for _ in range(n_layers)],
        tl.Select([0], n_in=2),  # Drops the masks.
        tl.LayerNorm(),
        tl.Mean(axis=1),  # Averages activations over the sequence axis.
        tl.Dense(n_classes),
    )