コード例 #1
0
ファイル: reformer.py プロジェクト: rizwandel/trax
def EncoderBlock(d_model,
                 d_ff,
                 n_heads,
                 attention_type,
                 dropout,
                 ff_activation,
                 ff_dropout,
                 ff_use_sru=0,
                 ff_chunk_size=0,
                 ff_sparsity=0,
                 attention_chunk_size=0,
                 center_layernorm=True,
                 use_bfloat16=False,
                 use_two_swaps_per_block=True,
                 mode='train'):
    """Builds the list of layers forming one reversible (Reformer) encoder block.

    The block operates on a pair (activations, mask); the mask is derived from
    the source tokens so that attention ignores padding positions.

    Args:
      d_model: int, depth of the embedding.
      d_ff: int, depth of the feed-forward layer.
      n_heads: int, number of attention heads; each head gets
        `d_model // n_heads` units for both queries/keys and values.
      attention_type: attention class handed to `ct.ApplyAttentionLayer`.
      dropout: float, dropout rate; used for attention dropout, attention
        output dropout and the feed-forward block.
      ff_activation: non-linearity used in the feed-forward layer.
      ff_dropout: dropout rate inside the feed-forward layer.
      ff_use_sru: int; if > 0, use this many SRU layers instead of feed-forward.
      ff_chunk_size: int; if > 0, chunk feed-forward into chunks of this size.
      ff_sparsity: int; if > 0, use a sparse feed-forward block with this
        sparsity.
      attention_chunk_size: int; if > 0, run attention chunked at this size.
      center_layernorm: whether LayerNorm centers its input (default); when
        False the centering is skipped, which is known as RMS normalization.
      use_bfloat16: whether to use bfloat16 for weights (default: False).
      use_two_swaps_per_block: bool; if True a reversible swap follows the
        feed-forward half as well, otherwise the block has a single swap.
      mode: str, 'train' or 'eval'; 'predict' is accepted and handled as
        'eval'.

    Returns:
      A list of layers that maps (activations, mask) to (activations, mask).
    """
    # 'predict' means token-at-a-time decoding; the encoder always processes
    # whole sequences, so it runs in 'eval' mode in that case.
    block_mode = 'eval' if mode == 'predict' else mode
    d_head = d_model // n_heads

    self_attention = ct.ApplyAttentionLayer(
        attention_type=attention_type,
        d_model=d_model,
        n_heads=n_heads,
        d_qk=d_head,
        d_v=d_head,
        masked=True,
        causal=False,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=attention_chunk_size,
        mode=block_mode)

    # TODO(lukaszkaiser): refactor efficient attention layers to unify the API
    # A two-output attention layer is the standard one: feed it a reshaped mask
    # and drop the mask it returns, to match the EfficientAttention API.
    if self_attention.n_out == 2:
        self_attention = tl.Serial(
            tl.Parallel([], _InsertAxes12()),
            self_attention,
            tl.Select([0], n_in=2))

    # First half: pre-normalized attention residual, then swap the two streams.
    block = [
        tl.ReversibleHalfResidual(
            tl.LayerNorm(center=center_layernorm),
            attention_layer=self_attention,
            name='ReversibleHalfResidualEncoderAttn'),
        tl.ReversibleSwap(),
    ]

    # Second half: feed-forward residual, optionally followed by another swap.
    feed_forward = ct.FeedForwardWithOptions(
        d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
        ff_chunk_size, ff_use_sru, ff_sparsity, center_layernorm, block_mode,
        use_bfloat16)
    block.append(
        tl.ReversibleHalfResidual(feed_forward,
                                  name='ReversibleHalfResidualEncoderFF'))
    if use_two_swaps_per_block:
        block.append(tl.ReversibleSwap())

    return block
コード例 #2
0
ファイル: transformer.py プロジェクト: yaoshuyin/trax
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
    """Builds a Transformer encoder topped with an N-way classification head.

    The model is meant for text categorization:

      - input: rank 2 tensor of token IDs with padding markers, shape
        (batch_size, sequence_length). Elements are integers in
        `range(vocab_size)`; `0` marks a padding position.

      - output: rank 2 tensor of shape (batch_size, `n_classes`) holding the
        per-class outputs of the final `Dense` layer, computed from the mean
        over the sequence axis of the normalized encoder activations.

    Args:
      vocab_size: Input vocabulary size; every input element should be an
        integer in `range(vocab_size)`, typically a token ID produced by a
        vocabulary-based tokenizer.
      n_classes: Size of the last output dimension (number of categories).
      d_model: Last dimension of tensors at most points in the model,
        including the embedding output.
      d_ff: Size of the special dense layer in the feed-forward part of each
        encoder block.
      n_layers: Number of encoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Probability of dropping an activation value when dropout is
        applied within an encoder block (and after the embedding).
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        saves memory and applies consistent masks across sequence positions.
      mode: If `'train'`, dropout is active; otherwise values pass through
        unaltered.
      ff_activation: Activation layer type used at the end of each encoder
        block; must be an activation-type subclass of `Layer`.

    Returns:
      A `tl.Serial` model mapping token-ID batches to per-class activations.
    """
    # tokens --> vectors
    embed_tokens = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation)
    stacked_blocks = [_EncoderBlock(*block_args) for _ in range(n_layers)]

    # Comments on the right track the data stack after each layer.
    return tl.Serial(                                   # toks
        # Embed the tokens and compute the padding mask side by side.
        tl.Branch(embed_tokens, tl.PaddingMask()),      # vecs masks
        stacked_blocks,                                 # vecs masks
        tl.Select([0], n_in=2),                         # vecs
        tl.LayerNorm(),                                 # vecs

        # Pool over the sequence axis and project to the categories.
        tl.Mean(axis=1),                                # vecs
        tl.Dense(n_classes),                            # vecs
    )
コード例 #3
0
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               n_chunks=0,
               n_attention_chunks=1,
               attention_type=tl.DotProductCausalAttention,
               share_qk=False,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               mode='train'):
    """Builds a reversible Transformer language model (decoder only).

    Args:
      vocab_size: int, vocabulary size.
      d_model: int, depth of *each half* of the two-part features.
      d_ff: int, depth of the feed-forward layer.
      d_attention_key: int, depth of the key vector for each attention head.
      d_attention_value: int, depth of the value vector for each attention
        head.
      n_layers: int, number of decoder layers.
      n_heads: int, number of attention heads.
      dropout: float, dropout rate (how much to drop out).
      max_len: int, maximum symbol length for positional encoding.
      n_chunks: int, number of chunks (must match the input pipeline); 0 means
        the input is not chunked.
      n_attention_chunks: int, number of chunks for attention.
      attention_type: attention class to use, such as DotProductAttention, or
        a tuple/list of such classes that is cycled through layer by layer
        (its length must divide `n_layers`).
      share_qk: bool, whether to share queries and keys; always enabled for
        layers whose attention class derives from `tl.LSHCausalAttention`.
      axial_pos_shape: tuple of ints, input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints, depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model. Required when `axial_pos_shape` is set.
      mode: str, 'train' or 'eval'.

    Returns:
      The model, as a `tl.Serial` layer.
    """
    # n_chunks == 0 means "unchunked": nothing to concatenate on the way in,
    # and the output is treated as a single section.
    if n_chunks != 0:
        concatenate_input_chunks = tl.Concatenate(n_items=n_chunks)
    else:
        concatenate_input_chunks = []
        n_chunks = 1

    if axial_pos_shape:
        assert d_axial_pos_embs is not None
        axial_dropout_dims = tuple(range(1, len(axial_pos_shape) + 1))
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=axial_dropout_dims,
            dropout=dropout)
    else:
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout)

    positional_embedder = [
        tl.Embedding(d_model, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    # Normalize attention_type to a sequence that is cycled over the layers.
    if not isinstance(attention_type, (tuple, list)):
        attention_type = [attention_type]
    else:
        assert n_layers % len(attention_type) == 0

    def _decoder_block(layer_idx):
        """Creates the decoder block for the layer at position `layer_idx`."""
        block_attention = attention_type[layer_idx % len(attention_type)]
        # LSH attention is always built with shared queries and keys.
        block_share_qk = (share_qk or issubclass(block_attention,
                                                 tl.LSHCausalAttention))
        return DecoderBlock(
            d_model,
            d_ff,
            d_attention_key,
            d_attention_value,
            n_heads,
            n_attention_chunks,
            attention_type=block_attention,
            dropout=dropout,
            share_qk=block_share_qk,
            mode=mode)

    decoder_blocks = [_decoder_block(idx) for idx in range(n_layers)]

    reversible_stack = tl.ReversibleSerial(decoder_blocks + [
        SplitForOutput(n_sections=n_chunks, axis=-2),  # pylint: disable=no-value-for-parameter
    ])

    # Applied to each of the n_chunks output sections.
    output_head = [
        # TODO(kitaev): Test whether dropout should go before or after the
        # LayerNorm, and whether dropout broadcasting is needed here.
        tl.LayerNorm(),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]

    return tl.Serial(
        concatenate_input_chunks,
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),
        reversible_stack,
        Map(output_head, n_sections=n_chunks),
    )
コード例 #4
0
 def ProcessingLayer():
     """Returns `Dense(lowrank)` followed by `Dense(d_feature)` as one Serial.

     Both sizes are read from the enclosing scope.
     """
     narrow = tl.Dense(lowrank)
     widen = tl.Dense(d_feature)
     return tl.Serial(narrow, widen)
コード例 #5
0
ファイル: transformer.py プロジェクト: yaoshuyin/trax
def TransformerDecoder(vocab_size=None,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
    """Builds a Transformer decoder.

    The model maps sequential inputs to sequential outputs:

      - input when `vocab_size` is given: rank 2 tensor of token IDs with
        padding markers, shape (batch_size, sequence_length). Elements are
        integers in `range(vocab_size)`; `0` marks a padding position.

      - input when `vocab_size` is None: rank 3 tensor of activation vectors,
        shape (batch_size, sequence_length, `d_model`).

      - output: rank 3 tensor of shape
        (batch_size, sequence_length, `d_model`).

    The model uses causal attention and does *not* shift the input to the
    right, so the output at position `t` depends on inputs up to and
    including position `t`.

    Args:
      vocab_size: If specified, the input vocabulary size; each input element
        should then be an integer in `range(vocab_size)` and is embedded. If
        None, the model expects floating point vectors with `d_model`
        components, which are passed through a `Dense(d_model)` layer instead.
      d_model: Last dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of the special dense layer in the feed-forward part of each
        decoder block.
      n_layers: Number of decoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Probability of dropping an activation value when dropout is
        applied within a decoder block (and after the input layer).
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        saves memory and applies consistent masks across sequence positions.
      mode: If `'train'`, dropout is active; otherwise values pass through
        unaltered.
      ff_activation: Activation layer type used at the end of each decoder
        block; must be an activation-type subclass of `Layer`.

    Returns:
      If `vocab_size` is defined: a Transformer model that maps token-ID
      sequences to sequences of activation vectors.

      If `vocab_size` is None: a Transformer model that maps sequences of
      activation vectors to sequences of activation vectors.
    """
    # Token IDs are embedded; pre-vectorized input is projected to d_model.
    if vocab_size is None:
        input_layer = tl.Dense(d_model)
    else:
        input_layer = tl.Embedding(vocab_size, d_model)

    input_encoder = [
        input_layer,
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

    block_args = (d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                  ff_activation)
    stacked_blocks = [_DecoderBlock(*block_args) for _ in range(n_layers)]

    # Comments on the right track the data stack after each layer.
    return tl.Serial(       # toks
        input_encoder,      # vecs
        stacked_blocks,     # vecs
        tl.LayerNorm(),     # vecs
    )
コード例 #6
0
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            axial_pos_shape=None,
                            d_axial_pos_embs=None):
  """Builds a full encoder-decoder Transformer with configurable sub-blocks.

  The model performs tokenized string-to-string ("source"-to-"target")
  transduction, e.g. translating tokenized sentences from English to German:

    - inputs (2):

        - source: rank 2 tensor of token IDs with padding markers, shape
          (batch_size, sequence_length). Elements are integers in
          `range(input_vocab_size)`; `0` marks a padding position.

        - target: rank 2 tensor of token IDs with padding markers, shape
          (batch_size, sequence_length). Elements are integers in
          `range(output_vocab_size)`; `0` marks a padding position.

    - output: rank 3 tensor of log-probability distributions over possible
      token IDs for each sequence position; shape is
      (batch_size, sequence_length, `vocab_size`). A copy of the target
      tokens is kept on the stack next to it for use in the loss.

  Args:
    input_vocab_size: Input vocabulary size; every element of the input tensor
      should be an integer in `range(vocab_size)`, typically a token ID from a
      vocabulary-based tokenizer.
    output_vocab_size: If specified, the vocabulary size for the targets; if
      None, input and target token IDs are assumed to come from the same
      vocabulary.
    d_model: Last dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of the special dense layer in the feed-forward part of each
      encoder and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Probability of dropping an activation value when dropout is
      applied within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) saves memory
      and applies consistent masks across sequence positions.
    mode: If `'predict'`, use fast inference (the encoder is wrapped in
      `tl.Cache`). If `'train'`, each encoder/decoder block includes dropout;
      otherwise values pass through unaltered.
    ff_activation: Activation layer type used at the end of each
      encoder/decoder block; must be an activation-type subclass of `Layer`.
    ff_dropout: Probability of dropping an activation value when dropout is
      applied after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into chunks of this size.
    ff_use_sru: int; if > 0, use this many SRU layers instead of feed-forward.
    ff_sparsity: int; if > 0, use a sparse feed-forward block with this
      sparsity.
    ff_sparsity_type: string; when ff_sparsity > 0, selects SparseFF for
      `'1inN'` and BlockSparseFF for `'Block'`.
    attention_chunk_size: int; if > 0, run attention chunked at this size.
    encoder_attention_type: The attention layer to use for the encoder part.
    encoder_decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    axial_pos_shape: tuple of ints, input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints, depth of position embedding for each
      axis. Tuple length must match axial_pos_shape, and values must sum to
      d_model.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  # The helper also hands back the output vocab size to use from here on.
  in_encoder, out_encoder, output_vocab_size = EmbeddingAndPositionalEncodings(
      input_vocab_size,
      d_model,
      mode,
      dropout,
      dropout_shared_axes,
      max_len,
      output_vocab_size=output_vocab_size,
      axial_pos_shape=axial_pos_shape,
      d_axial_pos_embs=d_axial_pos_embs)

  # Positional arguments common to encoder and encoder-decoder blocks; only
  # the trailing attention type differs between the two.
  common_block_args = (
      d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
      ff_activation, ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
      ff_sparsity_type, attention_chunk_size)

  encoder_stack = [
      EncoderBlock(*common_block_args, encoder_attention_type)
      for _ in range(n_encoder_layers)
  ]
  encoder = tl.Serial(in_encoder, encoder_stack, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_stack = [
      EncoderDecoderBlock(*common_block_args, encoder_decoder_attention_type)
      for _ in range(n_decoder_layers)
  ]

  # Assemble the model; comments on the right track the data stack.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Keep a second copy of the decoder tokens for use in the loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),    # tok_e masks ..... .....
      encoder,                            # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),               # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),           # tok_d ..... ..... .....
      out_encoder,                        # vec_d ..... ..... .....
      tl.Branch(
          [], tl.EncoderDecoderMask()),   # vec_d masks ..... .....
      decoder_stack,                      # vec_d masks ..... .....
      tl.LayerNorm(),                     # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d tok_d
      tl.Dense(output_vocab_size),        # vec_d .....
      tl.LogSoftmax(),                    # vec_d .....
  )
コード例 #7
0
ファイル: trainer_lib.py プロジェクト: dioptre/trax
        'model_state',  # Auxilliary state of the model.
    ])

# Immutable record bundling the optimizer-side training state:
#   weights    -- model weights.
#   slots      -- per-parameter optimizer state, e.g. gradient moments.
#   opt_params -- optimizer (hyper)parameters, e.g. learning rate, momentum.
OptState = collections.namedtuple(
    '_OptState', ('weights', 'slots', 'opt_params'))

# Maps a metric name to the layer that computes it. The layers are
# instantiated once, when this module is imported.
# NOTE(review): presumably the metric set used when a caller supplies none --
# confirm against the Trainer constructor.
_DEFAULT_METRICS = {
    'loss': tl.CrossEntropyLoss(),
    'accuracy': tl.Accuracy(),
    'sequence_accuracy': tl.SequenceAccuracy(),
    # Cross-entropy loss with its sign flipped.
    'neg_log_perplexity': tl.Serial(tl.CrossEntropyLoss(), tl.Negate()),
    'weights_per_batch_per_core': tl.SumOfWeights(),
}


class Trainer(object):
    """Trax trainer.

  A trainer allows to make training steps, train for full epochs,
  save the training state and access evaluation data.
  """
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
コード例 #8
0
def BERT(
    d_model=768,
    vocab_size=30522,
    max_len=512,
    type_vocab_size=2,
    n_heads=12,
    d_ff=3072,
    n_layers=12,
    head=None,
    init_checkpoint=None,
    mode='eval',
):
    """Builds a BERT model (default hparams are for bert-base-uncased).

    Args:
      d_model: Depth of the embeddings and of every encoder layer output.
      vocab_size: Size of the word-piece vocabulary.
      max_len: Maximum length handled by the positional encoding.
      type_vocab_size: Number of distinct token-type (segment) IDs.
      n_heads: Number of attention heads; each has depth `d_model // n_heads`.
      d_ff: Width of the hidden layer in each feed-forward sub-block.
      n_layers: Number of encoder layers.
      head: Optional zero-argument callable; when given, the layer it returns
        is chained after the BERT body.
      init_checkpoint: Checkpoint forwarded to `PretrainedBERT`; it is only
        forwarded when `mode == 'train'`, otherwise None is passed.
      mode: Mode string handed to the positional encoding and self-attention
        layers.

    Returns:
      A `PretrainedBERT` layer, or a `tl.Serial` of it followed by `head()`
      when `head` is given.
    """
    norm_eps = 1e-12
    per_head_depth = d_model // n_heads

    token_embedding = tl.Embedding(vocab_size, d_model)
    segment_embedding = tl.Embedding(type_vocab_size, d_model)
    position_encoding = tl.PositionalEncoding(max_len, mode=mode)

    # Embedding stage. The model takes three inputs; the third ('idx') is
    # dropped and replaced by a second copy of the first, from which the
    # padding mask is computed.
    layers = [
        tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
        tl.Parallel(token_embedding, segment_embedding, [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)
        ]),
        tl.Add(),
        position_encoding,
        tl.LayerNorm(epsilon=norm_eps),
    ]

    # Encoder stage: n_layers repetitions of attention + feed-forward, each
    # wrapped in a residual connection and followed by LayerNorm.
    for _ in range(n_layers):
        self_attention = tl.SelfAttention(n_heads=n_heads,
                                          d_qk=per_head_depth,
                                          d_v=per_head_depth,
                                          bias=True,
                                          masked=True,
                                          mode=mode)
        ff_sublayers = [tl.Dense(d_ff), tl.Gelu(), tl.Dense(d_model)]
        layers.extend([
            tl.Select([0, 1, 1]),  # Save a copy of the mask
            tl.Residual(self_attention, AddBias()),  # pylint: disable=no-value-for-parameter
            tl.LayerNorm(epsilon=norm_eps),
            tl.Residual(*ff_sublayers),
            tl.LayerNorm(epsilon=norm_eps),
        ])

    layers.append(tl.Select([0], n_in=2))  # Drop the mask

    # Pooler stage: emit the first-position vector next to the full sequence,
    # then Dense + Tanh.
    layers.extend([
        tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
        tl.Dense(d_model),
        tl.Tanh(),
    ])

    # A checkpoint is only passed on in 'train' mode.
    checkpoint = init_checkpoint if mode == 'train' else None
    model = PretrainedBERT(layers, init_checkpoint=checkpoint)

    if head is None:
        return model
    return tl.Serial(model, head())
コード例 #9
0
    def __init__(self,
                 blocks,
                 loss_layer,
                 optimizer_fn,
                 n_devices=None,
                 memoize_jit=True):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
      memoize_jit: Whether to memoize JITed functions; this significantly speeds
        up XLA compilation of larger models, but it uses `repr(layer)` as keys
        to memoize so it could fail if two layers with different functionality
        had the same string representation. We have not encountered such a case
        yet so this is turned on by default, but consider turning it off or
        reviewing your model if you use custom layers and encounter a problem.
    """
        # Each block's standard layers are merged into a single Serial layer;
        # its reversible layers are kept as they were passed in.
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.device_count()
        # Per block: one merged standard layer plus its reversible layers.
        # The leading 1 presumably accounts for the loss layer.
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])
        self._n_steps_per_log = 100  # Log layers and stats every 100 steps.
        # Cache for JITed functions, or None when memoization is turned off.
        self._jit_memory = {} if memoize_jit else None

        # Create accelerated versions of layers as pmaped/jited pure_fn.
        # The second argument to _pjit is a string key built from repr(layer).
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn, f'fwd {repr(layer)}'),
            self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            # A fresh optimizer per layer, initialized from that layer's weights.
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            return opt

        # Both nested structures below mirror the structure of self._blocks.
        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them in short FBO for "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        # One entry per block: (jitted standard-layer FBO, list of jitted
        # reverse-and-FBO functions for that block's reversible layers).
        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt,
                                              self._n_devices)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbo = _reverse_and_fbo_with_layer_and_opt(
                    layer, opt, self._n_devices)
                # donate_argnums is forwarded to self._pjit unchanged.
                rev_and_fbos.append(
                    self._pjit(rev_and_fbo,
                               f'rev+bwd {repr(layer)}',
                               donate_argnums=(1, 2)))
            jit_std_fbo = self._pjit(std_fbo,
                                     f'bwd {repr(std_layer)}',
                                     donate_argnums=(1, 2))
            self._fbos.append((jit_std_fbo, rev_and_fbos))

        # The loss layer gets its own FBO; unlike the ones above, it is passed
        # to _pjit without a string key.
        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss')
        self._loss_fbo = self._pjit(loss_fbo, donate_argnums=(1, 2))
コード例 #10
0
    def __init__(self, blocks, loss_layer, optimizer_fn, n_devices=None):
        """Creates a ReversibleSerialTrainer and the needed optimizers.

    This trainer performs updates equivalent to using the default Trainer on::

      tl.Serial(blocks + [loss_layer]).

    It is more memory-efficient though since weights are stored on CPU and only
    sent to accelerator layer-by-layer. Blocks are pairs consisting of a list
    of standard (arbitrary) layers and a list of reversible layers which help
    save memory thanks to being reversible.

    Note: when running on the JAX backend, constructing this trainer replaces
    the private `jax.api._check_inexact_input_vjp` with a no-op for the whole
    process.

    Args:
      blocks: A list of pairs of lists of standard and reversible layers.
      loss_layer: The final layer of the model; it can have trainable weights
        but should end with a loss: it is required to produce a scalar output.
      optimizer_fn: A function to create the optimizer, e.g., `optimizers.Adam`.
      n_devices: An optional integer, number of accelerator devices to use;
        by default, all available accelerators will be used.
    """
        # TODO(lukaszkaiser): remove these 2 lines once PR #4039 lands for JAX.
        # Global side effect: disables JAX's private input check.
        if fastmath.is_backend(fastmath.Backend.JAX):
            jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access
        # Each block's standard layers are merged into a single Serial layer;
        # its reversible layers are kept as they were passed in.
        self._blocks = [(tl.Serial(std), rev) for (std, rev) in blocks]
        self._loss_layer = loss_layer
        self._optimizer_fn = optimizer_fn
        self._n_devices = n_devices or fastmath.device_count()
        # Per block: one merged standard layer plus its reversible layers.
        # The leading 1 presumably accounts for the loss layer.
        self._n_layers = 1 + sum([len(revs) + 1 for (_, revs) in self._blocks])

        # Create accelerated versions of layers as pmaped/jited pure_fn.
        self._accelerated_layer_fns = fastmath.nested_map(
            lambda layer: self._pjit(layer.pure_fn), self._blocks)

        # Create per-layer optimizers and replicate opt_params.
        def _make_optimizer(layer):
            # A fresh optimizer per layer, initialized from that layer's weights.
            opt = optimizer_fn()
            opt.tree_init(layer.weights)
            return opt

        # Both nested structures below mirror the structure of self._blocks.
        self._optimizers = fastmath.nested_map(_make_optimizer, self._blocks)
        self._replicated_opt_params = fastmath.nested_map(
            lambda opt: self._replicate(opt.opt_params), self._optimizers)

        self._loss_opt = _make_optimizer(loss_layer)
        self._replicated_loss_opt_params = self._replicate(
            self._loss_opt.opt_params)

        # Forward + backward + optimizer-update functions for all layers.
        # We call them in short FBO for "Forward + Backward + Optimizer update".
        # Reversible layers define a reverse_and_fbo function that also reverses.

        # One entry per block: (jitted standard-layer FBO, list of jitted
        # reverse-and-FBO functions for that block's reversible layers).
        self._fbos = []
        for i, (std_layer, rev_layers) in enumerate(self._blocks):
            (std_opt, rev_opts) = self._optimizers[i]
            std_fbo = _fbo_with_layer_and_opt(std_layer, std_opt,
                                              self._n_devices)
            rev_and_fbos = []
            for layer, opt in zip(rev_layers, rev_opts):
                rev_and_fbos.append(
                    self._pjit(
                        _reverse_and_fbo_with_layer_and_opt(
                            layer, opt, self._n_devices)))
            self._fbos.append((self._pjit(std_fbo), rev_and_fbos))

        # The loss layer gets its own FBO, built with the extra 'loss' argument.
        loss_fbo = _fbo_with_layer_and_opt(self._loss_layer, self._loss_opt,
                                           self._n_devices, 'loss')
        self._loss_fbo = self._pjit(loss_fbo)
コード例 #11
0
ファイル: Intro to Trax.py プロジェクト: giantianye/NLP
#

# In[12]:

# help(tl.Serial)
# help(tl.Parallel)

# In[13]:

# Serial combinator: chains its sublayers so that each one consumes the output
# of the one before it.
serial = tl.Serial(
    tl.LayerNorm(),  # normalize input
    tl.Relu(),  # convert negative values to zero
    times_two,  # the custom layer created earlier in this notebook; multiplies the input received from the layer above by 2

    ### START CODE HERE
    #     tl.Dense(n_units=2),  # try adding more layers, e.g. uncomment these lines
    #     tl.Dense(n_units=1),  # Binary classification, maybe? Uncomment at your own peril
    #     tl.LogSoftmax()       # Yes, LogSoftmax is also a layer
    ### END CODE HERE
)

# Initialization: the model is initialized from the signature of a sample input.
x = np.array([-2, -1, 0, 1, 2])  # sample input
serial.init(shapes.signature(x))  # initialize the serial instance

# Show the assembled model and some of its properties.
print("-- Serial Model --")
print(serial, "\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
コード例 #12
0
def TransformerNoEncDecAttention(input_vocab_size,
                                 output_vocab_size=None,
                                 d_model=512,
                                 d_ff=2048,
                                 n_encoder_layers=6,
                                 n_decoder_layers=6,
                                 n_heads=8,
                                 dropout=0.1,
                                 dropout_shared_axes=None,
                                 max_len=2048,
                                 mode='train',
                                 ff_activation=tl.Relu):
    """Returns a Transformer model without encoder-decoder attention.

  Instead of attending from the decoder to the encoder, the encoder output is
  concatenated with the (shifted) decoder embeddings, the decoder blocks run
  over the concatenation, and the encoder part is stripped off again at the
  end.

  The assembled layer consumes a pair of token arrays in the order
  (encoder-side tokens, decoder-side tokens).

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the same
      embedding layers are used on both sides.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'; with 'predict' the encoder is additionally
      wrapped in tl.Cache.
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a pair of token arrays to
    activations over a vocab set.
  """
    def _embed_with_positions(vocab_size):  # tokens --> vectors
        embedding = tl.Embedding(vocab_size, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=dropout_shared_axes,
                                       mode=mode)
        positions = tl.PositionalEncoding(max_len=max_len)
        return [embedding, embedding_dropout, positions]

    def _stack(make_block, n_layers):
        # Every block in a stack is built from the same hyper-parameters.
        return [
            make_block(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                       mode, ff_activation) for _ in range(n_layers)
        ]

    in_encoder = _embed_with_positions(input_vocab_size)
    if output_vocab_size is None:
        # Shared vocabulary: reuse the very same embedding layers.
        out_encoder = in_encoder
        output_vocab_size = input_vocab_size
    else:
        out_encoder = _embed_with_positions(output_vocab_size)

    # pylint: disable=protected-access
    encoder_blocks = _stack(transformer._EncoderBlock, n_encoder_layers)
    # pylint: enable=protected-access

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=protected-access
    decoder_blocks = _stack(transformer._DecoderBlock, n_decoder_layers)
    # pylint: enable=protected-access

    # Assemble the model; trailing comments track the data stack.
    pipeline = [
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),  # tok_e vec_e tok_e tok_d tok_d
        tl.Fn('EncoderMask', lambda x: x != 0,
              n_out=1),  # mask_e vec_e tok_e tok_d tok_d

        # Decode.
        tl.Select([3, 1, 0, 2]),  # tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),  # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Decoder blocks with causal attention
        decoder_blocks,  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    ]
    return tl.Serial(*pipeline)
コード例 #13
0
ファイル: reformer.py プロジェクト: rizwandel/trax
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              pos_type='fixed-base',
              pos_axial_shape=(),
              pos_d_axial_embs=None,
              pos_start_from_zero_prob=1.0,
              pos_max_offset_to_add=0,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              loss_sparsity_type='mult',
              loss_sparsity=0,
              loss_d_lowrank=0,
              loss_sparsity_prob=None,
              attention_chunk_size=0,
              n_layers_forget=0,
              forget_dense=True,
              n_decoder_attention_layers=2,
              use_bfloat16=False,
              reversible_encoder=False,
              use_two_swaps_per_encoder_block=True,
              center_layernorm=True,
              half_before_layer=None,
              double_after_layer=None,
              mode='train'):
    """Reversible transformer encoder-decoder model.

  If input_vocab_size is not None, this model expects an input pair: source,
  target. Otherwise, it expects a triple: embedded_source, mask, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head;
      defaults to d_model // n_heads.
    d_attention_value: int: depth of value vector for each attention head;
      defaults to d_model // n_heads.
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class (or tuple/list of classes, cycled
      over the decoder layers): attention class to use, such as SelfAttention
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    pos_start_from_zero_prob: how often to start from 0 during training,
          (if 1.0, we always start from position 0, if less, we randomize).
    pos_max_offset_to_add: maximum offset to add to positions during training
        when randomizing; this offset plus input length must still be less than
        max_len for all training examples.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    loss_sparsity_type: str, type of sparsity to used in loss layer. See
      SparseDenseWithOptions for options. None if no sparsity should be used.
    loss_sparsity: int, the sparsity for loss layer (if used)
    loss_d_lowrank: int, the dimensions for intermediate layer (if used)
    loss_sparsity_prob: float, the probability for sparse version of loss to be
      used. If None, only sparse version is used.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    n_layers_forget: how often to have a forgetting block between layers
    forget_dense: whether to use Dense or no-op (Serial) as a forget layer.
    n_decoder_attention_layers: how many attention layers in a decoder block
    use_bfloat16: whether to use bfloat16 for weights (default: False)
    reversible_encoder: whether to be reversible through the encoder
    use_two_swaps_per_encoder_block: whether to allow even number of swaps in
      the encoder
    center_layernorm: whether to use centering in LayerNorm (default) or if
      to skip it, which is known as RMS normalization.
    half_before_layer: int, half d_model and d_ff (and the attention key/value
      depths) before that layer. Halving uses integer division so the
      resulting sizes stay ints.
    double_after_layer: int, double d_model and d_ff (and the attention
      key/value depths) after that layer
    mode: str: 'train' or 'eval'; with 'predict' the encoder is wrapped in
      tl.Cache and a _ConvertToNaNsOnAnyZero layer is inserted before it.

  Returns:
    A Reformer model as a layer that maps from a source, target pair (or an
    embedded_source, mask, target triple) to activations over a vocab set.

  Raises:
    ValueError: if n_heads does not divide d_model / 2.
  """
    # Set default dimensions for attention head key and value sizes.
    if (d_model / 2) % n_heads != 0:
        raise ValueError(
            f'n_heads ({n_heads}) must divide d_model/2 ({d_model/2})')
    if d_attention_key is None:
        d_attention_key = d_model // n_heads
    if d_attention_value is None:
        d_attention_value = d_model // n_heads

    # Set values of d_model, d_ff and d_qkv for the first stage.
    d_model1, d_ff1 = d_model, d_ff
    d_attention_key1, d_attention_value1 = d_attention_key, d_attention_value
    if half_before_layer:
        # Floor division keeps these sizes integral; true division would hand
        # floats (e.g. 256.0) to the layers as dimensions.
        d_model1, d_ff1 = d_model // 2, d_ff // 2
        d_attention_key1 = d_attention_key // 2
        d_attention_value1 = d_attention_value // 2

    # Set values of d_model, d_ff and d_qkv for the final stage.
    d_model2, d_ff2 = d_model, d_ff
    d_attention_key2, d_attention_value2 = d_attention_key, d_attention_value
    if double_after_layer:
        d_model2, d_ff2 = d_model * 2, d_ff * 2
        d_attention_key2 = d_attention_key * 2
        d_attention_value2 = d_attention_value * 2

    # Vector embeddings.
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model1,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            pos_type=pos_type,
            pos_axial_shape=pos_axial_shape,
            pos_d_axial_embs=pos_d_axial_embs,
            pos_start_from_zero_prob=pos_start_from_zero_prob,
            pos_max_offset_to_add=pos_max_offset_to_add,
            use_bfloat16=use_bfloat16))

    def _EncoderBlock():
        """Builds one encoder block at the first-stage sizes."""
        return EncoderBlock(
            d_model1,
            d_ff1,
            n_heads,
            encoder_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            use_two_swaps_per_block=use_two_swaps_per_encoder_block,
            mode=mode)

    def _Encoder():  # vec_e mask_e tok_e tok_d tok_d
        """Builds the encoder stack; cached in 'predict' mode."""
        layers = [
            tl.ReversibleSelect([0, 0]),
            _ReversibleSerialForget(
                [_EncoderBlock() for _ in range(n_encoder_layers)], d_model1,
                n_layers_forget, forget_dense)
        ]
        if not reversible_encoder:
            # Merge the two reversible streams back into a single one.
            layers += [
                _XYAvg(),
                tl.Dense(d_model1, use_bfloat16=use_bfloat16),
                tl.LayerNorm(),
            ]
        if mode == 'predict':
            return tl.Cache(tl.Serial(layers))
        else:
            return tl.Serial(layers)

    decoder_blocks = []

    # A single attention class is used for every layer; a tuple/list of
    # classes is cycled over the decoder layers.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        # Grow d_model, d_ff, and d_qkv if requested.
        d_m, d_f, d_k, d_v = d_model1, d_ff1, d_attention_key1, d_attention_value1
        if half_before_layer and layer_idx >= half_before_layer:
            d_m, d_f, d_k, d_v = d_model, d_ff, d_attention_key, d_attention_value
        if double_after_layer and layer_idx > double_after_layer:
            d_m, d_f, d_k, d_v = d_model2, d_ff2, d_attention_key2, d_attention_value2
        decoder_block = DecoderBlock(
            d_m,
            d_f,
            d_k,
            d_v,
            n_heads,
            attention_type=layer_attention_type,
            dropout=dropout,
            ff_activation=ff_activation,
            ff_dropout=ff_dropout,
            ff_use_sru=ff_use_sru,
            ff_chunk_size=ff_chunk_size,
            ff_sparsity=ff_sparsity,
            attention_chunk_size=attention_chunk_size,
            n_attention_layers=n_decoder_attention_layers,
            center_layernorm=center_layernorm,
            use_bfloat16=use_bfloat16,
            mode=mode)
        decoder_blocks.append(decoder_block)
        # A ReversibleConcatenatePair is inserted at each point where the
        # layer sizes selected above change.
        if half_before_layer and layer_idx == half_before_layer - 1:
            decoder_blocks.append(tl.ReversibleConcatenatePair())
        if double_after_layer and layer_idx == double_after_layer:
            decoder_blocks.append(tl.ReversibleConcatenatePair())

    def _Loss():
        """Builds the final layer mapping to the output vocab."""
        return tl.SparseDenseWithOptions(output_vocab_size,
                                         d_input=d_model2,
                                         sparsity_type=loss_sparsity_type,
                                         sparsity=loss_sparsity,
                                         d_lowrank=loss_d_lowrank,
                                         prob_sparse=loss_sparsity_prob,
                                         use_bfloat16=use_bfloat16,
                                         mode=mode)

    def _enc_dec_concat():
        """Layers to merge encoder and decoder."""
        if reversible_encoder:
            return [
                tl.ReversibleSelect([0, 1, 4, 2,
                                     3]),  # v_e v_d mask_e tok_e tok_d
                t2.ConcatWithPadding2(mode=mode),  # v_ed v_ed tok_e tok_d
            ]
        else:
            return [
                tl.ReversibleSelect([0, 3, 1,
                                     2]),  # v_e v_d mask_e tok_e tok_d
                t2.ConcatWithPadding(mode=mode),  # v_ed tok_e tok_d
                tl.ReversibleSelect([0, 0]),  # v_ed v_ed tok_e tok_d
            ]

    def _inp_layers():
        """Layers turning the model input into vec_e mask_e tok_e tok_d."""
        if input_vocab_size is not None:
            return tl.AssertFunction(
                'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth
                tl.Serial(  # tok_e tok_d
                    tl.Select([0, 0, 0, 1]),
                    tl.Parallel(
                        in_encoder,
                        [tl.PaddingMask(), _RemoveAxes12()
                         ])))  # vec_e mask_e tok_e tok_d
        else:
            # Input in this case is vec_e, mask_e, tok_d. Where all downstream
            # operations expect tok_e, we give it instead mask_e, expecting that
            # downstream ops only are looking for padding/not padding.
            return tl.AssertFunction(
                'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth
                tl.Serial(  # vec_e mask_e tok_d
                    tl.Select([0, 1, 1, 2]),
                    tl.Parallel(in_encoder, [],
                                _AsTokenIDs())))  # vec_e mask_e tok_e tok_d

    # Assemble and return the model.
    return tl.Serial(
        _inp_layers(),  # vec_e mask_e tok_e tok_d
        tl.Select([0, 1, 2, 3, 3]),  # Copy decoder tokens for use in loss.

        # Embed in and out tokens; done together as weights may be shared.
        tl.Parallel([], [], [], [tl.ShiftRight(mode=mode), out_encoder
                                 ]),  # vec_e mask_e tok_e vec_d tok_d

        # Predict mode doesn't work with padding in encoder. Raising an exception
        # in jitted function isn't possible, so the next best thing is to convert
        # every embedding to NaNs, so the user will get unmistakably wrong
        # results.
        (_ConvertToNaNsOnAnyZero() if mode == 'predict' else []),

        # Encode; then concat encoder and decoder, given encoder mask.
        _Encoder(),  # vec_e mask_e tok_e vec_d tok_d
        _enc_dec_concat(),

        # Run decoder blocks.
        _ReversibleSerialForget(decoder_blocks, d_model2, n_layers_forget,
                                forget_dense),  # vec_ed1 vec_ed2 tok_e tok_d
        _XYAvg(),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector,
        # then compute loss.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        t2.StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d
        _Loss(),  # vec_d tok_d
    )
コード例 #14
0
ファイル: reformer.py プロジェクト: rizwandel/trax
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             pos_type=None,
             pos_axial_shape=None,
             pos_d_axial_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
    """Reversible transformer encoder-decoder model.

  The assembled layer consumes a pair of token arrays in the order
  (encoder-side tokens, decoder-side tokens).

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'; with 'predict' the encoder is additionally
      wrapped in tl.Cache.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    A Reformer model as a layer that maps from a pair of token arrays to
    activations over a vocab set.
  """
    embeddings = ct.EmbeddingAndPositionalEncodings(
        input_vocab_size,
        d_model,
        mode,
        dropout,
        [-2],  # dropout_shared_axes
        max_len,
        output_vocab_size=output_vocab_size,
        pos_type=pos_type,
        pos_axial_shape=pos_axial_shape,
        pos_d_axial_embs=pos_d_axial_embs)
    in_encoder, out_encoder, output_vocab_size = embeddings

    def _encoder_block():
        # One self-attention encoder block with the shared hyper-parameters.
        return EncoderBlock(d_model,
                            d_ff,
                            n_heads,
                            tl.SelfAttention,
                            dropout,
                            ff_activation,
                            ff_dropout,
                            mode=mode,
                            ff_use_sru=ff_use_sru,
                            ff_chunk_size=ff_chunk_size,
                            ff_sparsity=ff_sparsity)

    def _encoder_decoder_block():
        # One decoder-side block with the shared hyper-parameters.
        return EncoderDecoderBlock(d_model,
                                   d_ff,
                                   n_heads,
                                   dropout,
                                   ff_activation,
                                   ff_dropout,
                                   mode,
                                   ff_use_sru=ff_use_sru,
                                   ff_chunk_size=ff_chunk_size,
                                   ff_sparsity=ff_sparsity)

    encoder_blocks = [_encoder_block() for _ in range(n_encoder_layers)]

    # Duplicate the embeddings, run the reversible stack on the pair, then
    # average the two streams back into one.
    encoder_layers = [
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        _XYAvg(),
        tl.LayerNorm(),
    ]
    encoder = tl.Serial(encoder_layers)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        _encoder_decoder_block() for _ in range(n_decoder_layers)
    ]

    # Assemble the model; trailing comments track the data stack.
    pipeline = [
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch(
            [],
            [tl.PaddingMask(), _RemoveAxes12()]),  # tok_e mask  tok_d .....

        # Encode.
        encoder,  # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        _XYAvg(),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
    ]
    return tl.Serial(*pipeline)
コード例 #15
0
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  With a shorten_factor of F and an input of shape [batch, length], the
  (shifted-right) input is embedded and every F consecutive elements (along
  length) are grouped into a single vector, so the main model body processes
  a tensor of shape
    [batch, length // F, d_model]
  Only at the very end is the sequence un-shortened and an SRU applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class (or tuple/list of classes, cycled over the layers):
      attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  # Choose between axial and plain positional encoding.
  if axial_pos_shape:
    assert d_axial_pos_embs is not None
    broadcast_dims = tuple(range(1, len(axial_pos_shape) + 1))
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=broadcast_dims,
        dropout=dropout)
  else:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # A single attention class is used for every layer; a tuple/list of classes
  # is cycled over the layers.
  if not isinstance(attention_type, (tuple, list)):
    attention_type = [attention_type]
  else:
    assert n_layers % len(attention_type) == 0

  def _decoder_block(layer_idx):
    layer_attention_type = attention_type[layer_idx % len(attention_type)]
    # LSH attention always shares queries and keys.
    layer_share_qk = (share_qk or issubclass(layer_attention_type,
                                             tl.LSHCausalAttention))
    return DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=layer_share_qk,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        mode=mode)

  decoder_blocks = [_decoder_block(idx) for idx in range(n_layers)]

  # pylint: disable=g-long-lambda
  pipeline = [
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Parallel([], tl.Drop()),
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax(),
  ]
  return tl.Serial(*pipeline)
コード例 #16
0
def extract_reversible_blocks(layers, loss_chunk_size=0):
    """Splits layers into blocks and a loss for ReversibleSerialTrainer.

  The (flattened) layers are cut into blocks, each a pair of a run of standard
  layers followed by a run of reversible layers; the standard layers left
  after the last reversible run form the loss layer.

  Args:
    layers: a list of layers or a single layer to extract blocks from;
      should end with a loss, e.g., [model, loss] or tl.Serial(model, loss).
    loss_chunk_size: int, if > 0 creates a chunked loss layer to save memory
      in models with larger vocabulary; requires the last sublayers of loss
      are [Dense, LogSoftmax, _CrossEntropy, _WeightedMean] in that order.

  Returns:
    a pair (blocks, loss_layer) to use with ReversibleSerialTrainer.

  Raises:
    ValueError: if the layers end with a reversible layer, or if chunking is
      requested and the loss does not end with the four layers named above.
  """
    def _flatten(item):
        """Flatten all Serial layers and sub(sub-...) layers into a list."""
        if isinstance(item, (list, tuple)):
            flat = []
            for sub in item:
                flat.extend(_flatten(sub))
            return flat
        if isinstance(item, tl.Serial):
            return _flatten(item.sublayers)
        return [item]

    # Extract standard and reversible layer blocks.
    blocks = []
    std_layers, rev_layers = [], []
    for layer in _flatten(layers):
        if isinstance(layer, tl.ReversibleLayer):
            rev_layers.append(layer)
            continue
        if rev_layers:
            # A standard layer after a reversible run closes the block.
            blocks.append((std_layers, rev_layers))
            std_layers, rev_layers = [], []
        std_layers.append(layer)

    if rev_layers:
        raise ValueError(
            'The final layer must be a standard loss, not reversible.')

    if not loss_chunk_size > 0:
        return blocks, tl.Serial(std_layers)

    # For now we only do chunking of [Dense, LogSoftmax, CrossEntropy, Mean]
    # Let's check that these are the last 4 layers.
    if len(std_layers) < 4:
        raise ValueError('Too short loss layer for chunking')
    tail = std_layers[-4:]
    # To check for Dense, remove the n_units part from name:
    # just 'Dense' not e.g., 'Dense_32000'.
    tail_names = [tail[0].name[:5]]
    tail_names.extend(sublayer.name for sublayer in tail[1:])
    last_4_names = ' '.join(tail_names)
    if last_4_names != 'Dense LogSoftmax _CrossEntropy _WeightedMean':
        raise ValueError(
            'Loss chunking only works with last layers being "Dense'
            ' LogSoftmax, _CrossEntropy, _WeightedMean" but got: ' +
            last_4_names)

    # Create chunked dense+logsoftmax+cross-entropy-loss.
    chunked_xent = tl.Chunk(tl.Serial(tail[:-1]), loss_chunk_size)

    # The chunked loss should operate on a merged batch dimension, e.g.,
    # including both length and batch size. Need to merge and un-merge later.
    def _reshape_to_batch_and_copy_targets(preds, targets):
        flat_preds = jnp.reshape(preds, [-1, preds.shape[-1]])
        flat_targets = jnp.reshape(targets, [-1])
        return flat_preds, flat_targets, targets

    def _reshape_xent_back(xent, targets):
        return jnp.reshape(xent, targets.shape)

    rebatch_in = tl.Fn('pre_xent_rebatch',
                       _reshape_to_batch_and_copy_targets,
                       n_out=3)
    rebatch_out = tl.Fn('after_xent_rebatch', _reshape_xent_back)
    batched_xent = tl.Serial(rebatch_in, chunked_xent, rebatch_out)

    # Everything before the last four layers stays as is; the final mean
    # layer runs on the un-merged cross-entropy.
    loss_layer = tl.Serial(std_layers[:-4] + [batched_xent], std_layers[-1])
    return blocks, loss_layer
コード例 #17
0
def ReZeroTransformer(input_vocab_size,
                      output_vocab_size=None,
                      d_model=512,
                      d_ff=2048,
                      n_encoder_layers=6,
                      n_decoder_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      max_len=2048,
                      mode='train',
                      ff_activation=tl.Relu):
    """Builds a ReZero encoder-decoder transformer.

  The model takes an input pair: source tokens, target tokens.

  When `output_vocab_size` is None, the token embedding and its dropout layer
  are the same layer instances on the encoder and decoder sides. In 'predict'
  mode the encoder is built in 'eval' mode and wrapped in `tl.Cache`.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A `tl.Serial` layer that takes a (source, target) pair and ends in a
    `Dense(output_vocab_size)` projection.
  """
    def _embed_tokens(n_tokens):  # tokens --> vectors
        token_embedding = tl.Embedding(n_tokens, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=dropout_shared_axes,
                                       mode=mode)
        return [token_embedding, embedding_dropout]

    source_embedding = _embed_tokens(input_vocab_size)
    if output_vocab_size is None:
        # Single vocab: reuse the very same embedding layer instances and size
        # the output projection by the source vocab.
        target_embedding = source_embedding
        output_vocab_size = input_vocab_size
    else:
        target_embedding = _embed_tokens(output_vocab_size)

    # Each side gets its own positional encoding layer. The encoder does not
    # run stepwise, so it never uses 'predict' mode.
    if mode == 'predict':
        encoder_mode = 'eval'
    else:
        encoder_mode = mode
    source_pipeline = [
        *source_embedding,
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
    ]
    target_pipeline = [
        *target_embedding,
        tl.PositionalEncoding(max_len=max_len, mode=mode),
    ]

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            _EncoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    encoder_stack = tl.Serial(source_pipeline, encoder_blocks,
                              tl.LayerNorm())
    encoder = tl.Cache(encoder_stack) if mode == 'predict' else encoder_stack

    decoder_blocks = []
    for _ in range(n_decoder_layers):
        decoder_blocks.append(
            _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                 dropout_shared_axes, mode, ff_activation))

    # Input: encoder_side_tokens, decoder_side_tokens
    model_layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks ..... .....
        encoder,  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),  # tok_d ..... ..... .....
        target_pipeline,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
    ]
    return tl.Serial(*model_layers)
コード例 #18
0
ファイル: trainer_lib_test.py プロジェクト: stephenjfox/trax
 def model_fn(mode='train'):
     """Builds a Dropout -> BatchNorm -> MLP model for the given `mode`."""
     input_dropout = tl.Dropout(mode=mode, rate=0.1)
     batch_norm = tl.BatchNorm(mode=mode)
     # Two hidden layers of width 16, then one output unit per class.
     classifier = models.MLP(layer_widths=(16, 16, n_classes), mode=mode)
     return tl.Serial(input_dropout, batch_norm, classifier)
コード例 #19
0
def FeedForwardWithOptions(d_model,
                           d_ff,
                           dropout,
                           dropout_shared_axes,
                           ff_activation,
                           ff_dropout,
                           ff_chunk_size,
                           ff_use_sru,
                           ff_sparsity,
                           mode,
                           use_bfloat16=False,
                           ff_sparsity_type='1inN'):
  """Selects and builds the feed-forward part of a block.

  The variants are tried in this order: SRU layers (if `ff_use_sru`), a
  `SparseFF` block (if `ff_sparsity` and type `'1inN'`), a `BlockSparseFF`
  block (if `ff_sparsity` and type `'Block'`), and otherwise the regular
  chunked feed-forward block. A truthy `ff_sparsity` with any other
  `ff_sparsity_type` also falls through to the regular block.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    ff_activation: Type of activation function at the end of each block; must be
      an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity.
      For the `'1inN'` type a `(n_elements_in_block, d_lowrank)` tuple is
      accepted as well; an int `n` stands for `(n, d_ff // n)`.
    mode: If `'train'`, each block will include dropout; else, it will pass all
      values through unaltered.
    use_bfloat16: whether to use bfloat16 for weights (default: False).
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`

  Returns:
    A list of layers which maps vectors to vectors.
  """
  # SRU layers replace the feed-forward block entirely.
  if ff_use_sru:
    return [tl.SRU(d_model) for _ in range(ff_use_sru)]

  if ff_sparsity:
    if ff_sparsity_type == '1inN':
      if not isinstance(ff_sparsity, tuple):
        assert isinstance(ff_sparsity, int)
        block_size = ff_sparsity
        lowrank_depth = d_ff // ff_sparsity
      else:
        block_size, lowrank_depth = ff_sparsity
      sparse_ff = tl.SparseFF(
          d_ff,
          n_elements_in_block=block_size,
          d_lowrank=lowrank_depth,
          mode=mode)
      # Chunk only when a positive chunk size was requested.
      ff_layer = sparse_ff if ff_chunk_size < 1 else tl.BatchLeadingAxes(
          tl.Chunk(tl.Serial(sparse_ff), ff_chunk_size))
      return [
          tl.LayerNorm(),
          ff_layer,
          tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      ]

    if ff_sparsity_type == 'Block':
      return [
          tl.LayerNorm(),
          tl.BlockSparseFF(d_ff, num_experts=ff_sparsity, mode=mode),
          tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      ]

  # Dense fallback: no sparsity, or an unrecognized sparsity type.
  return [
      ChunkedFeedForward(d_model, d_ff, dropout, ff_activation, ff_dropout,
                         ff_chunk_size, use_bfloat16, mode)
  ]
コード例 #20
0
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
    """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Note: on the jax backend, building this model replaces
  `jax.api._check_inexact_input_vjp` with a no-op for the whole process (see
  the comment at the top of the function body).

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention. May also be a tuple or list of classes, which are then
      cycled over the decoder layers; n_decoder_layers must be a multiple of
      its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
      (forwarded to the decoder blocks only)
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      (forwarded to the decoder blocks only)
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T. (Forwarded to the
      encoder blocks only.)
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder is
      built in 'eval' mode and wrapped in tl.Cache.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    # The current API for custom gradients assumes that a layer must be
    # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
    # masks on the stack. This causes jax to error, even though the so-called
    # "gradient" wrt the masks is never actually computed.
    # TODO(kitaev): remove this hack.
    if math.backend_name() == 'jax':
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
        if not axial_pos_shape:
            positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                        dropout=dropout,
                                                        mode=mode)
        else:
            assert d_axial_pos_embs is not None
            positional_encoding = tl.AxialPositionalEncoding(
                shape=axial_pos_shape,
                d_embs=d_axial_pos_embs,
                dropout_broadcast_dims=tuple(range(1,
                                                   len(axial_pos_shape) + 1)),
                dropout=dropout,
                mode=mode)

        return [
            tl.Embedding(d_model, vocab_size),
            tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
            positional_encoding,
        ]

    # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
    # position embeddings between the encoder and decoder if output_vocab_size is
    # None. This isn't supported here because (a) Trax shares weights by sharing
    # layer instances, but we need two separate instances to have mode == 'eval'
    # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
    # not work if its sublayers participate in any weight sharing.

    # Mode 'predict' means that the decoder should be run one token at a time.
    # The encoder only ever runs over full sequences, which is why it's switched
    # to 'eval' mode instead.
    in_encoder = PositionalEncoder(input_vocab_size,
                                   mode='eval' if mode == 'predict' else mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = PositionalEncoder(output_vocab_size, mode)

    # `mode` is passed by keyword on purpose: EncoderBlock has optional
    # parameters right after `ff_dropout`, so a positional `mode` would be
    # bound to the wrong parameter.
    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model, d_ff, n_heads, encoder_attention_type, dropout,
                     ff_activation, ff_dropout, mode=mode)
        for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([  # tok_e mask_e tok_e tok_d tok_d
        in_encoder,  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    decoder_blocks = []

    # A single attention class is treated as a one-element cycle.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
    else:
        encoder_decoder_attention_type = [encoder_decoder_attention_type]
    for layer_idx in range(n_decoder_layers):
        layer_attention_type = encoder_decoder_attention_type[
            layer_idx % len(encoder_decoder_attention_type)]
        decoder_block = DecoderBlock(d_model,
                                     d_ff,
                                     d_attention_key,
                                     d_attention_value,
                                     n_heads,
                                     attention_type=layer_attention_type,
                                     dropout=dropout,
                                     ff_activation=ff_activation,
                                     ff_use_sru=ff_use_sru,
                                     ff_chunk_size=ff_chunk_size,
                                     mode=mode)
        decoder_blocks.append(decoder_block)

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                         # tok_e mask_e tok_e tok_d tok_d

        # Encode.
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Decode.
        tl.Select([3, 0, 1, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        tl.Branch([], _MaskOfRightShiftedArray()
                  ),  # stok_d mask_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d mask_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder, given their masks.
        tl.Select([2, 0, 3, 1]),  # vec_e svec_d mask_e mask_d tok_e tok_d
        _ConcatWithPadding(),  # vec_ed tok_e tok_d

        # Run (encoder and) decoder blocks.
        tl.Dup(),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        _StripFromConcatenateWithPadding(),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    )
コード例 #21
0
ファイル: trainer_lib.py プロジェクト: dioptre/trax
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):
        """Builds the train/eval models, metrics and jitted functions.

        Construction ends with a call to `self.reset(output_dir)`, which is
        where the fields initialized to None below get their values.

        Args:
          model: callable; invoked as `model(mode='train')` and
            `model(mode='eval')` to build the two model instances.
          loss_fn: loss layer; appended to the training model with tl.Serial.
          optimizer: callable; invoked as `optimizer(learning_rate=0.0)`.
          lr_schedule: stored as-is for later use.
          inputs: an Inputs instance, or a function returning one (called
            with no arguments).
          output_dir: passed to `reset()`. If falsy, both checkpoint saving
            and summary writing are switched off.
          random_seed: forwarded to `training.init_host_and_devices`.
          n_devices: forwarded to `training.init_host_and_devices`.
          checkpoints_at: stored as-is; defaults to an empty list.
          should_save_checkpoints: bool; only takes effect on the chief host
            and when `output_dir` is set.
          should_write_summaries: bool; only takes effect when `output_dir`
            is set.
          metrics: dict from metric name to metric layer; defaults to
            `_DEFAULT_METRICS`.
          checkpoint_highest: stored as-is; presumably the name of a metric
            whose best (highest) value triggers a checkpoint -- TODO confirm.
          checkpoint_lowest: stored as-is; presumably the counterpart for the
            lowest metric value -- TODO confirm.
        """

        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        # Without an output directory there is nowhere to write anything.
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(
                inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        # `init_rng` seeds weight/metric initialization; the remaining `rng`
        # is split into one key per device.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        # Stored as a zero-argument factory; every call uses the same init_rng.
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        # Metric names are sorted so the order of the Branch outputs is fixed.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        # The metrics layers are initialized on the eval model's outputs.
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
コード例 #22
0
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 ff_dropout=0.1,
                 ff_chunk_size=0,
                 ff_use_sru=0,
                 ff_sparsity=0,
                 ff_sparsity_type='1inN',
                 attention_chunk_size=0,
                 encoder_attention_type=tl.Attention,
                 n_encoder_attention_layers=1,
                 decoder_attention_type=tl.CausalAttention,
                 n_decoder_attention_layers=2,
                 pos_type=None,
                 pos_axial_shape=None,
                 pos_d_axial_embs=None):
    """Builds an encoder-decoder Transformer that concatenates both sides.

  The encoder output and the embedded (right-shifted) decoder tokens are
  concatenated and run through the decoder blocks together; the encoder part
  is stripped off again before the final projection.

  The model takes an input pair: encoder-side tokens first, decoder-side
  tokens second. In 'predict' mode the encoder is wrapped in `tl.Cache`.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train', 'eval' or 'predict'
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string, if ff_sparsity >0,
      use SparseFF if ff_sparsity_type=`'1inN'` and
      use BlockSparseFF if ff_sparsity_type=`'Block'`
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    n_encoder_attention_layers: int, within each encoder block, how many
      attention layers to have.
    decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    n_decoder_attention_layers: int, within each decoder block, how many
      attention layers to have.
    pos_type: string, the type of positional embeddings to use.
    pos_axial_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    pos_d_axial_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match pos_axial_shape, and values must sum to d_model.

  Returns:
    A `tl.Serial` layer that takes the input pair and ends in a
    `Dense(output_vocab_size)` projection.
  """
    # The helper also hands back the resolved output vocab size.
    embedders = ct.EmbeddingAndPositionalEncodings(
        input_vocab_size,
        d_model,
        mode,
        dropout,
        dropout_shared_axes,
        max_len,
        output_vocab_size=output_vocab_size,
        pos_type=pos_type,
        pos_axial_shape=pos_axial_shape,
        pos_d_axial_embs=pos_d_axial_embs)
    source_embedder, target_embedder, output_vocab_size = embedders

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            ct.EncoderBlock(d_model, d_ff, n_heads, dropout,
                            dropout_shared_axes, mode, ff_activation,
                            ff_dropout, ff_chunk_size, ff_use_sru,
                            ff_sparsity, ff_sparsity_type,
                            attention_chunk_size, encoder_attention_type,
                            n_encoder_attention_layers))

    encoder_stack = tl.Serial(source_embedder, encoder_blocks,
                              tl.LayerNorm())
    encoder = tl.Cache(encoder_stack) if mode == 'predict' else encoder_stack

    decoder_blocks = []
    for _ in range(n_decoder_layers):
        decoder_blocks.append(
            ct.DecoderBlock(d_model, d_ff, n_heads, dropout,
                            dropout_shared_axes, mode, ff_activation,
                            ff_dropout, ff_chunk_size, ff_use_sru,
                            ff_sparsity, ff_sparsity_type,
                            attention_chunk_size, decoder_attention_type,
                            n_decoder_attention_layers))

    # Input: encoder_side_tokens, decoder_side_tokens
    model_layers = [
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),  #  tok_e vec_e tok_e tok_d tok_d
        tl.Fn(
            'EncoderMask',  # mask_e vec_e tok_e tok_d tok_d
            lambda x: x != 0,
            n_out=1),

        # Decode.
        tl.Select([3, 1, 0, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        target_embedder,  # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),  # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Decoder blocks with causal attention
        decoder_blocks,  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
    ]
    return tl.Serial(*model_layers)
コード例 #23
0
ファイル: trainer_lib_test.py プロジェクト: xibelly/trax
 def model_fn(mode='train'):
     """Builds a Dropout -> BatchNorm -> MLP model for the given `mode`."""
     input_dropout = tl.Dropout(mode=mode, rate=0.1)
     batch_norm = tl.BatchNorm(mode=mode)
     classifier = models.MLP(d_hidden=16,
                             n_output_classes=n_classes,
                             mode=mode)
     return tl.Serial(input_dropout, batch_norm, classifier)
コード例 #24
0
ファイル: reformer.py プロジェクト: dioptre/trax
def ReformerLM(vocab_size,
               d_model=512,
               d_ff=2048,
               d_attention_key=64,
               d_attention_value=64,
               n_layers=6,
               n_heads=8,
               dropout=0.1,
               max_len=2048,
               attention_type=tl.SelfAttention,
               axial_pos_shape=(),
               d_axial_pos_embs=None,
               ff_activation=tl.FastGelu,
               ff_use_sru=0,
               ff_chunk_size=0,
               ff_sparsity=0,
               mode='train'):
    """Builds a reversible transformer language model (decoder only).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out); also used as the
      feed-forward dropout rate of every decoder block
    max_len: int: maximum symbol length for positional encoding
    attention_type: class: attention class to use, such as SelfAttention. May
      also be a tuple or list of classes, which are then cycled over the
      layers; n_layers must be a multiple of its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    mode: str: 'train', 'eval', or 'predict'

  Returns:
    the layer.
  """
    position_layer = PositionalEncoding(mode, dropout, max_len,
                                        axial_pos_shape, d_axial_pos_embs)

    embedding_layers = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        position_layer,
    ]

    # Normalize to a sequence of attention classes; a single class is a
    # one-element cycle.
    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
        attention_cycle = attention_type
    else:
        attention_cycle = [attention_type]

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        DecoderBlock(d_model,
                     d_ff,
                     d_attention_key,
                     d_attention_value,
                     n_heads,
                     attention_type=attention_cycle[layer_idx %
                                                    len(attention_cycle)],
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     mode=mode) for layer_idx in range(n_layers)
    ]
    # pylint: enable=g-complex-comprehension

    model_layers = [
        tl.ShiftRight(mode=mode),
        embedding_layers,
        # The reversible blocks work on two copies of the activations, which
        # are concatenated again afterwards.
        tl.Dup(),
        tl.ReversibleSerial(decoder_blocks),
        tl.Concatenate(),
        # TODO(kitaev): Test whether dropout should go before or after the
        # LayerNorm, and whether dropout broadcasting is needed here.
        tl.LayerNorm(),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*model_layers)
コード例 #25
0
ファイル: transformer.py プロジェクト: yaoshuyin/trax
def TransformerLM(vocab_size,
                  d_model=512,
                  d_ff=2048,
                  n_layers=6,
                  n_heads=8,
                  max_len=2048,
                  dropout=0.1,
                  dropout_shared_axes=None,
                  mode='train',
                  ff_activation=tl.Relu):
    """Builds a decoder-only Transformer for autoregressive language modeling.

    The returned layer applies, in order: `ShiftRight`, token embedding with
    dropout, positional encoding, `n_layers` decoder blocks, `LayerNorm` and a
    final `Dense(vocab_size)` projection.

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length).
        Elements are integers in `range(vocab_size)`, and `0` values mark
        padding positions.

      - output: rank 3 tensor of shape (batch_size, sequence_length,
        `vocab_size`) with one activation vector per sequence position. No
        `LogSoftmax` is applied in this function, so the values are the
        outputs of the last `Dense` layer rather than normalized
        log-probabilities.

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
          should be an integer in `range(vocab_size)`. Also used as the width
          of the final `Dense` layer.
      d_model: Embedding depth; forwarded to every decoder block.
      d_ff: Size of the feed-forward layer, forwarded to every decoder block.
      n_layers: Number of decoder blocks to stack.
      n_heads: Number of attention heads, forwarded to every decoder block.
      max_len: Maximum symbol length for positional encoding.
      dropout: Dropout rate used after the embedding and forwarded to every
          decoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
          Sharing along batch and sequence axes
          (`dropout_shared_axes=(0,1)`) is a useful way to save memory and
          apply consistent masks to activation vectors at different sequence
          positions.
      mode: Forwarded unchanged to `ShiftRight`, `Dropout`,
          `PositionalEncoding` and the decoder blocks. Documented upstream as:
          `'predict'` uses fast inference, `'train'` enables dropout, any
          other value passes activations through unaltered.
      ff_activation: Activation layer type forwarded to every decoder block;
          must be an activation-type subclass of `Layer`.

    Returns:
      A Transformer language model as a layer that maps from a tensor of
      tokens to activations over a vocab set.
    """
    # tokens --> vectors
    token_embedding = tl.Embedding(vocab_size, d_model)
    embedding_dropout = tl.Dropout(rate=dropout,
                                   shared_axes=dropout_shared_axes,
                                   mode=mode)
    position_encoding = tl.PositionalEncoding(max_len=max_len, mode=mode)

    # One independently constructed decoder block per layer.
    blocks = []
    for _ in range(n_layers):
        blocks.append(
            _DecoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    return tl.Serial(
        tl.ShiftRight(mode=mode),  # shifted tokens
        [token_embedding, embedding_dropout, position_encoding],  # vectors
        blocks,  # vectors
        tl.LayerNorm(),  # vectors
        tl.Dense(vocab_size),  # per-token activations over the vocab
    )
コード例 #26
0
ファイル: reformer.py プロジェクト: dioptre/trax
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
    """Builds a reversible Transformer encoder-decoder model.

    The returned layer takes two inputs in this order: the encoder-side
    (source) tokens first, then the decoder-side (target) tokens. The decoder
    tokens are duplicated on the stack so that a copy remains available after
    the model output (e.g. for use in a loss).

    At the moment, this model supports dot-product attention only
    (`tl.SelfAttention` is used for every encoder block). For the attention
    types in the Reformer paper, see ReformerLM.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        `input_vocab_size` is used for the target as well.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      mode: str: 'train', 'eval' or 'predict'. With 'predict' the encoder-side
        embedding and positional layers are built in 'eval' mode and the
        whole encoder is wrapped in `tl.Cache`.

    Returns:
      A Reformer model as a layer that maps from a (source, target) token
      pair to `LogSoftmax` outputs over the output vocabulary.
    """
    # The custom-gradient API assumes a layer is differentiable with respect
    # to all of its inputs, yet bool-dtype masks travel on the stack here.
    # That makes jax raise even though no "gradient" for the masks is ever
    # computed, so the check is disabled globally.
    # TODO(kitaev): remove this hack.
    if fastmath.is_backend(fastmath.Backend.JAX):
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    def _embed_tokens(vocab_size, layer_mode):  # tokens --> vectors
        # TODO(kitaev): axial positional encoding is better for very long
        # sequences.
        position = tl.PositionalEncoding(max_len=max_len,
                                         dropout=dropout,
                                         mode=layer_mode)
        embedding = tl.Embedding(vocab_size, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=[-2],
                                       mode=layer_mode)
        return [embedding, embedding_dropout, position]

    def _average_pair():
        # Merges the two reversible streams back into a single tensor.
        return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)

    # 'predict' runs the decoder one token at a time, but the encoder always
    # sees complete sequences, so its input layers use 'eval' in that case.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = _embed_tokens(input_vocab_size, encoder_mode)
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
    out_encoder = _embed_tokens(output_vocab_size, mode)

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model,
                         d_ff,
                         n_heads,
                         tl.SelfAttention,
                         dropout,
                         ff_activation,
                         ff_dropout,
                         mode=mode))

    encoder_stages = [
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        _average_pair(),
        tl.LayerNorm(),
    ]
    encoder = tl.Serial(encoder_stages)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = []
    for _ in range(n_decoder_layers):
        encoder_decoder_blocks.append(
            EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                ff_activation, ff_dropout, mode))

    # Padding mask for the encoder tokens with axes 1 and 2 squeezed away.
    squeezed_padding_mask = [
        tl.PaddingMask(),
        tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
    ]

    # Stack comments list the stack top first.
    return tl.Serial(
        # Inputs: encoder_side_tokens, decoder_side_tokens.
        # Keep an extra copy of the decoder tokens for use in the loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch([], squeezed_padding_mask),  # tok_e mask tok_d .....

        # Encoder side.
        encoder,  # vec_e mask tok_d .....

        # Decoder side.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        _average_pair(),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Project onto the output vocabulary.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
コード例 #27
0
ファイル: transformer.py プロジェクト: yaoshuyin/trax
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                max_len=2048,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.Relu):
    """Builds a full encoder-decoder Transformer model.

    The model performs tokenized string-to-string ("source"-to-"target")
    transduction, e.g. translating (tokenized) sentences from English to
    German:

      - inputs (2):

          - source: rank 2 tensor representing a batch of text strings via
            token IDs plus padding markers; shape is (batch_size,
            sequence_length). Elements are integers in
            `range(input_vocab_size)`, and `0` values mark padding positions.

          - target: rank 2 tensor representing a batch of text strings via
            token IDs plus padding markers; shape is (batch_size,
            sequence_length). Elements are integers in
            `range(output_vocab_size)`, and `0` values mark padding positions.

      - output: rank 3 tensor of shape (batch_size, sequence_length,
        `vocab_size`) with one activation vector per target position. The
        stack ends in `Dense(output_vocab_size)` and applies no `LogSoftmax`
        in this function, so the values are not normalized log-probabilities.

    Args:
      input_vocab_size: Input vocabulary size -- each element of the source
          tensor should be an integer in `range(input_vocab_size)`.
      output_vocab_size: If specified, gives the vocabulary size for the
          targets and a separate target embedding is created. If None, the
          source embedding layers are reused for the targets and
          `input_vocab_size` is used for the output projection.
      d_model: Embedding depth; forwarded to every encoder/decoder block.
      d_ff: Size of the feed-forward layer, forwarded to every block.
      n_encoder_layers: Number of encoder blocks.
      n_decoder_layers: Number of decoder blocks.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Dropout rate used after the embeddings and forwarded to every
          block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
          Sharing along batch and sequence axes
          (`dropout_shared_axes=(0,1)`) is a useful way to save memory and
          apply consistent masks to activation vectors at different sequence
          positions.
      mode: Documented upstream as: `'predict'` uses fast inference, `'train'`
          enables dropout, any other value passes activations through
          unaltered. With `'predict'` the encoder-side positional encoding is
          built in `'eval'` mode and the encoder is wrapped in `tl.Cache`.
      ff_activation: Activation layer type forwarded to every block; must be
          an activation-type subclass of `Layer`.

    Returns:
      A Transformer model as a layer that maps from a source-target tokenized
      text pair to activations over a vocab set.
    """
    def _token_embedding(vocab_size):  # tokens --> vectors
        embedding = tl.Embedding(vocab_size, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=dropout_shared_axes,
                                       mode=mode)
        return [embedding, embedding_dropout]

    in_embedder = _token_embedding(input_vocab_size)
    if output_vocab_size is None:
        # Shared vocabulary: both sides use the very same embedding layers.
        out_embedder = in_embedder
        output_vocab_size = input_vocab_size
    else:
        out_embedder = _token_embedding(output_vocab_size)

    # Each side gets its own positional encoding layer. The encoder never
    # runs stepwise, so it must not be put into 'predict' mode.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = list(in_embedder)
    in_encoder.append(
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode))
    out_encoder = list(out_embedder)
    out_encoder.append(tl.PositionalEncoding(max_len=max_len, mode=mode))

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            _EncoderBlock(d_model, d_ff, n_heads, dropout,
                          dropout_shared_axes, mode, ff_activation))

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = []
    for _ in range(n_decoder_layers):
        encoder_decoder_blocks.append(
            _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                 dropout_shared_axes, mode, ff_activation))

    # Stack comments list the stack top first.
    return tl.Serial(
        # Inputs: encoder_side_tokens, decoder_side_tokens.
        # Keep an extra copy of the decoder tokens for use in the loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encoder side.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks ..... .....
        encoder,  # vec_e ..... ..... .....

        # Decoder side.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),  # tok_d ..... ..... .....
        out_encoder,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Project onto the output vocabulary.
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
    )
コード例 #28
0
ファイル: reformer.py プロジェクト: dioptre/trax
def Reformer2(input_vocab_size,
              output_vocab_size=None,
              d_model=512,
              d_ff=2048,
              d_attention_key=None,
              d_attention_value=None,
              n_encoder_layers=6,
              n_decoder_layers=6,
              n_heads=8,
              dropout=0.1,
              max_len=2048,
              encoder_attention_type=tl.SelfAttention,
              encoder_decoder_attention_type=tl.SelfAttention,
              axial_pos_shape='fixed-base',
              d_axial_pos_embs=None,
              ff_activation=tl.Relu,
              ff_use_sru=0,
              ff_chunk_size=0,
              ff_dropout=None,
              ff_sparsity=0,
              n_layers_forget=0,
              mode='train'):
    """Builds a reversible Transformer encoder-decoder model.

    The returned layer takes two inputs in this order: source tokens, then
    target tokens. The encoder output is concatenated with the embedded
    target, the decoder blocks run over the concatenation, and the encoder
    part is stripped off again before the output projection.

    NOTE(review): the upstream documentation stated that only dot-product
    attention is supported, yet the attention classes are configurable via
    `encoder_attention_type` and `encoder_decoder_attention_type` -- confirm
    which is current.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source embedding layers are reused for the target and
        `input_vocab_size` is used for the output projection.
      d_model: int: depth of embedding
      d_ff: int: depth of feed-forward layer
      d_attention_key: int: depth of key vector for each attention head;
        defaults to `d_model // n_heads`.
      d_attention_value: int: depth of value vector for each attention head;
        defaults to `d_model // n_heads`.
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: int: maximum symbol length for positional encoding
      encoder_attention_type: class: attention class passed to every encoder
        block, such as SelfAttention
      encoder_decoder_attention_type: class, or tuple/list of classes:
        attention class(es) for the decoder blocks. A sequence is cycled
        through layer by layer, and its length must divide
        `n_decoder_layers`.
      axial_pos_shape: forwarded to `PositionalEncoding`. Documented upstream
        as a tuple of ints giving the input shape for the axial position
        encoding. NOTE(review): the default here is the string
        'fixed-base' -- confirm its meaning against `PositionalEncoding`.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.
      ff_activation: the non-linearity in feed-forward layer
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_dropout: float: (optional) separate dropout rate at feed-forward
        nonlinearity. This is called relu_dropout in T2T.
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity
      n_layers_forget: how often to have a forgetting block between layers;
        forwarded to `_ReversibleSerialForget`.
      mode: str: 'train', 'eval' or 'predict'. With 'predict' the
        encoder-side positional encoding is built in 'eval' mode and the
        encoder is wrapped in `tl.Cache`.

    Returns:
      A Reformer model as a layer that maps from a (source, target) token
      pair to `LogSoftmax` outputs over the output vocabulary.

    Raises:
      ValueError: if a per-head key or value depth has to be derived and
        `n_heads` does not divide `d_model`.
    """
    # The custom-gradient API assumes a layer is differentiable with respect
    # to all of its inputs, yet bool-dtype masks travel on the stack here.
    # That makes jax raise even though no "gradient" for the masks is ever
    # computed, so the check is disabled globally.
    # TODO(kitaev): remove this hack.
    if fastmath.is_backend(fastmath.Backend.JAX):
        jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

    # Derive per-head key/value depths that were not given explicitly.
    if d_attention_key is None or d_attention_value is None:
        if d_model % n_heads != 0:
            raise ValueError(
                f'n_heads ({n_heads}) must divide d_model ({d_model})')
        depth_per_head = d_model // n_heads
        if d_attention_key is None:
            d_attention_key = depth_per_head
        if d_attention_value is None:
            d_attention_value = depth_per_head

    def _token_embedding(vocab_size):  # tokens --> vectors
        embedding = tl.Embedding(vocab_size, d_model)
        embedding_dropout = tl.Dropout(rate=dropout,
                                       shared_axes=[-2],
                                       mode=mode)
        return [embedding, embedding_dropout]

    def _positions(layer_mode):
        return PositionalEncoding(layer_mode, dropout, max_len,
                                  axial_pos_shape, d_axial_pos_embs)

    def _average_pair():
        # Merges the two reversible streams back into a single tensor.
        return tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)

    in_embedder = _token_embedding(input_vocab_size)
    if output_vocab_size is None:
        # Shared vocabulary: both sides use the very same embedding layers.
        out_embedder = in_embedder
    else:
        out_embedder = _token_embedding(output_vocab_size)

    # 'predict' runs the decoder one token at a time, but the encoder always
    # sees complete sequences, so its positional encoding uses 'eval' then.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = list(in_embedder)
    in_encoder.append(_positions(encoder_mode))
    out_encoder = list(out_embedder)
    out_encoder.append(_positions(mode))
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    encoder_blocks = []
    for _ in range(n_encoder_layers):
        encoder_blocks.append(
            EncoderBlock(d_model,
                         d_ff,
                         n_heads,
                         encoder_attention_type,
                         dropout=dropout,
                         ff_activation=ff_activation,
                         ff_dropout=ff_dropout,
                         ff_use_sru=ff_use_sru,
                         ff_chunk_size=ff_chunk_size,
                         ff_sparsity=ff_sparsity,
                         mode=mode))

    encoder_stages = [  # vec_e mask_e tok_e tok_d tok_d
        tl.Dup(),  # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
        _ReversibleSerialForget(encoder_blocks, d_model, n_layers_forget),
        _average_pair(),
        tl.Dense(d_model),
        tl.LayerNorm(),
    ]
    encoder = tl.Serial(encoder_stages)
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # Normalize the decoder attention spec to a sequence that is cycled
    # through, one entry per decoder layer.
    if isinstance(encoder_decoder_attention_type, (tuple, list)):
        assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
        attention_cycle = encoder_decoder_attention_type
    else:
        attention_cycle = [encoder_decoder_attention_type]
    cycle_length = len(attention_cycle)

    decoder_blocks = [
        DecoderBlock(d_model,
                     d_ff,
                     d_attention_key,
                     d_attention_value,
                     n_heads,
                     attention_type=attention_cycle[index % cycle_length],
                     dropout=dropout,
                     ff_activation=ff_activation,
                     ff_dropout=ff_dropout,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity,
                     mode=mode) for index in range(n_decoder_layers)
    ]

    # Stack comments list the stack top first.
    return tl.Serial(
        # Inputs: encoder_side_tokens, decoder_side_tokens.
        # Keep extra copies of the tokens for masking and for the loss.
        tl.Select([0, 0, 0, 1, 1]),  # tok_e tok_e tok_e tok_d tok_d

        # Embed both sides in one step, since the weights may be shared.
        tl.Parallel(
            in_encoder,
            [],
            [],  # vec_e tok_e tok_e vec_d tok_d
            [tl.ShiftRight(mode=mode), out_encoder]),
        # Turn the second entry into a padding mask, squeezing axes 1 and 2.
        tl.Parallel([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),  # vec_e mask_e tok_e vec_d tok_d

        # Encoder side.
        encoder,  # vec_e mask_e tok_e vec_d tok_d

        # Decoder side.
        tl.Select([3, 0, 1, 2]),  # vec_d vec_e mask_e tok_e tok_d

        # Join encoder and decoder vectors, taking the mask into account.
        tl.Select([1, 0]),  # vec_e vec_d mask_e tok_e tok_d
        _ConcatWithPadding(),  # vec_ed tok_e tok_d

        # Run the decoder blocks over the joined sequence.
        tl.Dup(),  # vec_ed1 vec_ed2 tok_e tok_d
        _ReversibleSerialForget(
            decoder_blocks, d_model,
            n_layers_forget),  # vec_ed1 vec_ed2 tok_e tok_d
        _average_pair(),  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Drop the encoder part of the joined sequence again.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        _StripFromConcatenateWithPadding(),  # vec_d tok_d

        # Project onto the output vocabulary.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    )
コード例 #29
0
def ConfigurableTransformerEncoder(vocab_size,
                                   n_classes=10,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   max_len=2048,
                                   dropout=0.1,
                                   dropout_shared_axes=None,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   ff_dropout=0.1,
                                   ff_chunk_size=0,
                                   ff_use_sru=0,
                                   ff_sparsity=0,
                                   ff_sparsity_type='1inN',
                                   attention_chunk_size=0,
                                   attention_type=tl.Attention,
                                   axial_pos_shape=None,
                                   d_axial_pos_embs=None):
    """Builds a Transformer encoder topped with an N-way classification head.

    The returned layer embeds the tokens, runs `n_layers` encoder blocks
    together with a padding mask, layer-normalizes, averages over axis 1 and
    maps the result to `n_classes` outputs followed by `LogSoftmax`:

      - input: rank 2 tensor representing a batch of text strings via token
        IDs plus padding markers; shape is (batch_size, sequence_length).
        Elements are integers in `range(vocab_size)`, and `0` values mark
        padding positions.

      - output: rank 2 tensor representing a batch of log-probability
        distributions over N categories; shape is (batch_size, `n_classes`).

    Args:
      vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
      n_classes: Final dimension of the output tensors, representing N-way
        classification.
      d_model: Embedding depth; forwarded to every encoder block.
      d_ff: Size of the feed-forward layer, forwarded to every encoder block.
      n_layers: Number of encoder blocks to stack.
      n_heads: Number of attention heads.
      max_len: Maximum symbol length for positional encoding.
      dropout: Dropout rate used after the embedding and forwarded to the
        positional encoder and every encoder block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      mode: If `'train'`, each encoder block will include dropout; else, it
        will pass all values through unaltered.
      ff_activation: Activation layer type forwarded to every encoder block;
        must be an activation-type subclass of `Layer`.
      ff_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity
      ff_sparsity_type: string, if ff_sparsity >0,
        use SparseFF if ff_sparsity_type=`'1inN'` and
        use BlockSparseFF if ff_sparsity_type=`'Block'`
      attention_chunk_size: int, if > 0 run attention chunked at this size
      attention_type: The attention layer to use for the encoder part.
      axial_pos_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.

    Returns:
      A Transformer model that maps strings (conveyed via token IDs) to
      probability-like activations over a range of output classes.
    """
    # tokens --> vectors
    token_embedding = tl.Embedding(vocab_size, d_model)
    embedding_dropout = tl.Dropout(rate=dropout,
                                   shared_axes=dropout_shared_axes,
                                   mode=mode)
    position_encoding = PositionalEncoder(mode, dropout, max_len,
                                          axial_pos_shape, d_axial_pos_embs)
    positional_encoder = [
        token_embedding, embedding_dropout, position_encoding
    ]

    # One independently constructed encoder block per layer.
    encoder_blocks = []
    for _ in range(n_layers):
        encoder_blocks.append(
            EncoderBlock(d_model, d_ff, n_heads, dropout,
                         dropout_shared_axes, mode, ff_activation,
                         ff_dropout, ff_chunk_size, ff_use_sru, ff_sparsity,
                         ff_sparsity_type, attention_chunk_size,
                         attention_type))

    return tl.Serial(  # tokens
        # Encoder: vectors travel together with the padding mask.
        tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
        encoder_blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Classification head.
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),  # vecs
    )
コード例 #30
0
 def test_const_div(self):
     """Expects [3, 6, 9, 12] fed through `DivideBy(3)` to give [1, 2, 3, 4]."""
     constants = np.array([3, 6, 9, 12])
     model = tl.Serial(ReturnConst(constants), DivideBy(3))
     # The model is called on an empty tuple, i.e. with no inputs.
     outputs = model(())
     self.assertEqual(as_list(outputs), [1, 2, 3, 4])