Example #1
def TransformerEncoder(vocab_size=vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu,
                       EncoderBlock=EncoderBlock):
    """
    Returns a Transformer encoder model.
    The input to the model is a tensor of tokens.
  
    Args:
        vocab_size (int): vocabulary size. Defaults to the globally defined vocab_size.
        n_classes (int): number of output classes. Defaults to 10.
        d_model (int): depth of embedding. Defaults to 512.
        d_ff (int): depth of feed-forward layer. Defaults to 2048.
        n_layers (int): number of encoder blocks. Defaults to 6.
        n_heads (int): number of attention heads. Defaults to 8.
        dropout (float): dropout rate (how much to drop out). Defaults to 0.1.
        dropout_shared_axes (tuple of ints or None): axes on which to share the dropout mask. Defaults to None.
        max_len (int): maximum symbol length for positional encoding. Defaults to 2048.
        mode (str): 'train' or 'eval'. Defaults to 'train'.
        ff_activation (function): the non-linearity in feed-forward layer. Defaults to tl.Relu.
        EncoderBlock (function): factory that returns an encoder block. Defaults to EncoderBlock.
  
    Returns:
        trax.layers.combinators.Serial: A Transformer model as a layer that maps
        from a tensor of tokens to activations over a set of output classes.
    """

    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len)
    ]

    ### START CODE HERE (REPLACE INSTANCES OF 'None' WITH YOUR CODE) ###

    # Use the function `EncoderBlock` (implemented above), passing in the parameters, to create `n_layers` encoder blocks
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     dropout,
                     dropout_shared_axes,
                     mode,
                     ff_activation,
                     FeedForwardBlock=FeedForwardBlock)
        for _ in range(n_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Encode
        tl.Branch(
            # Use `positional_encoder`
            positional_encoder,
            # Use trax padding mask
            tl.PaddingMask(),
        ),
        # Use `encoder_blocks`
        encoder_blocks,
        # Use select layer
        tl.Select([0], n_in=2),
        # Use trax layer normalization
        tl.LayerNorm(),
        # Map to output categories.
        # Use trax mean. set axis to 1
        tl.Mean(axis=1),
        # Use trax Dense using `n_classes`
        tl.Dense(n_classes),
        # Use trax log softmax
        tl.LogSoftmax(),
    )
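
A minimal usage sketch for the constructor above (it assumes the notebook-level `EncoderBlock`, `FeedForwardBlock`, and `vocab_size` from the same exercise are already defined and that `trax` is installed; treat it as an illustration, not part of the graded code):

import numpy as np
from trax import shapes

model = TransformerEncoder(vocab_size=500, n_classes=10, n_layers=2, mode='eval')
tokens = np.ones((1, 16), dtype=np.int32)   # a dummy batch of token IDs
model.init(shapes.signature(tokens))        # initialize weights for this input signature
log_probs = model(tokens)                   # shape (1, 10): log-probabilities over the classes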
Example #2
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use, such as SelfAttention
    encoder_decoder_attention_type: class: attention class to use, such as
      SelfAttention
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    if not axial_pos_shape:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)
    else:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)

    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([                # tok_e mask_e tok_e tok_d tok_d
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = []

  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    assert n_decoder_layers % len(encoder_decoder_attention_type) == 0
  else:
    encoder_decoder_attention_type = [encoder_decoder_attention_type]
  for layer_idx in range(n_decoder_layers):
    layer_attention_type = encoder_decoder_attention_type[
        layer_idx % len(encoder_decoder_attention_type)]
    decoder_block = DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        attention_type=layer_attention_type,
        dropout=dropout,
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)
    decoder_blocks.append(decoder_block)

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),                  # tok_e tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                         # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                                  # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),                 #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),                # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),                        # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                                    # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),         # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),           # vec_ed tok_e tok_d
      tl.LayerNorm(),                              # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                     # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),          # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),                 # vec_d tok_d
      tl.LogSoftmax(),                             # vec_d tok_d
  )
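
The `tl.Branch([], [tl.PaddingMask(), tl.Fn('Squeeze', ...)])` idiom near the top of the model copies the encoder tokens and, in parallel, derives a padding mask whose singleton attention axes are squeezed away. A small standalone sketch of just that pattern (the `jnp` import path can differ slightly between trax versions):

import numpy as np
from trax import layers as tl
from trax.fastmath import numpy as jnp

mask_branch = tl.Branch([], [tl.PaddingMask(),
                             tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)])
toks = np.array([[3, 7, 0, 0]])         # 0 marks padding positions
toks_out, mask = mask_branch(toks)      # mask has shape (1, 4), zero/False at the padded positions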
Example #3
  def test_noop_dup(self):
    layer = tl.Branch([], tl.Dup())
    x = np.array([1, 2, 3])
    ys = layer(x)
    self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3], [1, 2, 3]])
Example #4
  def test_one_sublayer(self):
    layer = tl.Branch(DivideBy(0.5))
    x = np.array([1, 2, 3])
    ys = layer(x)
    self.assertEqual(as_list(ys), [2, 4, 6])
Example #5
  def test_printing_sublayers(self):
    layer = tl.Branch(tl.Add(), tl.Add())
    expected_result = 'Branch_in2_out2[\n  Add_in2\n  Add_in2\n]'
    self.assertEqual(expected_result, str(layer))
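
These three tests rely on two helpers that are not shown here, `DivideBy` and `as_list`. Minimal sketches of what they presumably look like, inferred from how they are used above (the real definitions in the trax test suite may differ):

from trax import layers as tl

def DivideBy(val):
  # One-input, one-output layer that divides its input elementwise by `val`.
  return tl.Fn('DivideBy', lambda x: x / val)

def as_list(outputs):
  # Recursively convert an array, or a tuple/list of arrays, into nested Python lists.
  if isinstance(outputs, (list, tuple)):
    return [as_list(y) for y in outputs]
  return outputs.tolist()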
Example #6
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 has_weights=False,
                 nontrainable_param_map=None,
                 id_to_mask=None,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, self._n_devices, rng = (self._init_host_and_devices(
            n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._has_weights = has_weights
        self._id_to_mask = id_to_mask
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        loss_fn = loss_fn(has_weights=has_weights, id_to_mask=id_to_mask)
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        # If we pass a function, e.g., through gin, call it.
        if callable(inputs):
            self._inputs = inputs()

        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        # If the inputs are a tuple/list, add [None] (batch) to each element.
        if self._inputs.input_shape and isinstance(self._inputs.input_shape[0],
                                                   (list, tuple)):
            model_input_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.input_shape)
        else:  # Otherwise just add [None] to the input shape.
            model_input_shape = tuple([None] + list(self._inputs.input_shape))
        # Same for targets.
        if self._inputs.target_shape and isinstance(
                self._inputs.target_shape[0], (list, tuple)):
            model_target_shape = tuple(
                tuple([None] + list(shape))
                for shape in self._inputs.target_shape)
        else:
            model_target_shape = tuple([None] +
                                       list(self._inputs.target_shape))
        # Change all None to 1 in input and target shape.
        model_input_shape = math.nested_map(lambda x: x or 1,
                                            model_input_shape)
        model_target_shape = math.nested_map(lambda x: x or 1,
                                             model_target_shape)

        def new_opt_state_and_model_state(shape_dtype, rng):
            """Returns optimizer and model states suitable for training a model."""
            # Combine inputs and targets on the stack.
            shapes, dtypes = shape_dtype
            input_signature = tuple(
                ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))
            # We need to create a new model instance and not reuse `model_train` here,
            # because `m.init` puts cached parameter values in `m` and hence the
            # next call of `m.init` will give wrong results.
            m = tl.Serial(model(mode='train'), loss_fn)
            m._set_rng_recursive(rng)  # pylint: disable=protected-access
            weights, state = m.init(input_signature)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if _is_jit_init():
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = math.jit(
                new_opt_state_and_model_state, static_argnums=(0, ))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
                self._inputs.example_shape_dtype, init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [
            self._metrics_dict[m](has_weights=self._has_weights,
                                  id_to_mask=self._id_to_mask)
            for m in self._metrics
        ]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        # TODO(pkozakowski): "Learning rate schedules" are currently able to
        # control all optimizer parameters and model state, so let's rename them
        # accordingly.
        self._lr_schedule = lr_schedule

        if nontrainable_param_map is None:
            nontrainable_param_map = {}
        self._nontrainable_param_map = nontrainable_param_map

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._lr_fn = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
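
A small sketch of the rng bookkeeping done near the top of this constructor: one key is split off for parameter initialization and the remainder is split into one key per device (`jax_random` is `jax.random`; exact key shapes depend on the jax version):

import numpy as np
from jax import random as jax_random

rng = jax_random.PRNGKey(0)
rng, init_rng = jax_random.split(rng)
per_device_rngs = np.stack(jax_random.split(rng, 2))   # as if n_devices == 2
print(per_device_rngs.shape)                           # (2, 2): one PRNG key per device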
Example #7
def TransformerNoEncDecAttention(input_vocab_size,
                                 output_vocab_size=None,
                                 d_model=512,
                                 d_ff=2048,
                                 n_encoder_layers=6,
                                 n_decoder_layers=6,
                                 n_heads=8,
                                 dropout=0.1,
                                 dropout_shared_axes=None,
                                 max_len=2048,
                                 mode='train',
                                 ff_activation=tl.Relu):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  def PositionalEncoder(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                    mode, ff_activation)
      for i in range(n_encoder_layers)]

  encoder = tl.Serial(
      in_encoder,
      encoder_blocks,
      tl.LayerNorm()
  )
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = [
      _DecoderBlock(d_model, d_ff, n_heads,
                    dropout, dropout_shared_axes, mode, ff_activation)
      for i in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),            # tok_e tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),    # tok_e mask_e tok_e tok_d tok_d
      encoder,                            # vec_e mask_e tok_e tok_d tok_d

      # Simple encoder mask, doesn't contain extra dims.
      tl.Select([2, 0, 2], n_in=3),       # tok_e vec_e tok_e tok_d tok_d
      _MaskOfRightShiftedArray(
          n_shifts=0),                    # mask_e vec_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 1, 0, 2]),          #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),             # vec_ed tok_e tok_d

      # Decoder blocks with causal attention
      decoder_blocks,                     # vec_ed tok_e tok_d
      tl.LayerNorm(),                     # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),               # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),    # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),        # vec_d tok_d
      tl.LogSoftmax(),                    # vec_d tok_d
  )
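
Much of the stack bookkeeping above is done with `tl.Select`, which copies, reorders, or drops items on the data stack. A tiny standalone illustration of the first such layer in the model:

import numpy as np
from trax import layers as tl

sel = tl.Select([0, 0, 1, 1])           # a b -> a a b b
a, b = np.array([1]), np.array([2])
outputs = sel((a, b))                   # (array([1]), array([1]), array([2]), array([2]))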
Example #8
def FunnelTransformer(vocab_size,
                      d_model=512,
                      d_ff=2048,
                      encoder_segment_lengths=(2, 2, 2),
                      n_decoder_blocks=2,
                      n_heads=8,
                      max_len=2048,
                      dropout=0.1,
                      dropout_shared_axes=None,
                      mode='train',
                      ff_activation=tl.Relu,
                      pool_layer=tl.AvgPool,
                      pool_size=(2,),
                      separate_cls=True):
  """Returns a Full Funnel Transformer, that can be used for example for BERT.

  This model outputs token-level categorical distributions over all vocab:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions over `vocab_size` categories for each token; shape is
      (batch_size, sequence_length, vocab_size).


  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Tuple, where each element denotes the number of
        transformer encoder blocks preceding a funnel transformer block.
        There is no funnel block after the last sequence of encoder blocks,
        therefore the total number of blocks in the model is equal to
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_decoder_blocks: Number of transformer blocks in the upsampling decoder.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    separate_cls: If `True`, pooling in funnel blocks is not applied to the
        embedding of the first token (`cls` from the BERT paper), and only the
        final embedding of this token is used for categorization; the rest are
        discarded. If `False`, the first token is pooled like every other token,
        and all embeddings are averaged and mapped to output categories as in
        the original `TransformerEncoder` model.
  """
  assert encoder_segment_lengths

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  n_encoder_segments = len(encoder_segment_lengths)

  encoder_blocks_before_first_pooling = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout,
                    dropout_shared_axes, mode, ff_activation)
      for _ in range(encoder_segment_lengths[0])]
  encoder_blocks_from_first_pooling = []

  for i in range(1, n_encoder_segments):
    # Building i'th segment

    # Add funnel block between segments
    encoder_blocks_from_first_pooling.append(
        _FunnelBlock(d_model, d_ff, n_heads, dropout,
                     dropout_shared_axes, mode,
                     ff_activation, pool_layer,
                     pool_size=pool_size, strides=pool_size,
                     separate_cls=separate_cls))

    for _ in range(encoder_segment_lengths[i]):
      # Create segment_size encoder blocks
      encoder_blocks_from_first_pooling.append(
          _EncoderBlock(d_model, d_ff, n_heads, dropout,
                        dropout_shared_axes, mode, ff_activation))

  decoder_blocks = [_EncoderBlock(d_model, d_ff, n_heads, dropout,
                                  dropout_shared_axes, mode, ff_activation)
                    for _ in range(n_decoder_blocks)]

  total_pool_size = pool_size[0] ** (len(encoder_segment_lengths) - 1)

  # Assemble and return the model.
  return tl.Serial(                               # toks
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks_before_first_pooling,        # vecs masks
      tl.Select([0, 1, 0, 1]),
      # vecs masks residual = vecs old_masks
      encoder_blocks_from_first_pooling,          # vecs masks residual masks
      tl.Select([0, 2, 3]),                       # vecs residual masks
      tl.Parallel(
          # residual from first segment is taken before
          # normalization, so apply it now
          None, tl.LayerNorm(), None),            # vecs norm(residual) masks
      _Upsampler(total_pool_size, separate_cls),  # vecs masks
      decoder_blocks,
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),
      tl.Dense(vocab_size),
  )
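
A hypothetical shape-check sketch for this constructor (it assumes the module-private helpers `_EncoderBlock`, `_FunnelBlock`, and `_Upsampler` are in scope, e.g. by importing the constructor from its trax module, and that the sequence length is divisible by the total pooling factor):

import numpy as np
from trax import shapes

model = FunnelTransformer(vocab_size=256, d_model=64, d_ff=128,
                          encoder_segment_lengths=(1, 1), n_decoder_blocks=1,
                          n_heads=2, max_len=64, mode='eval')
toks = np.ones((2, 16), dtype=np.int32)
model.init(shapes.signature(toks))
out = model(toks)                       # expected shape (2, 16, 256): per-token activations over the vocab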
Example #9
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train',
             axial_pos_shape=None,
             d_axial_pos_embs=None,
             ff_use_sru=0,
             ff_chunk_size=0,
             ff_sparsity=0):
    """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(
            input_vocab_size,
            d_model,
            mode,
            dropout,
            [-2],  # dropout_shared_axes
            max_len,
            output_vocab_size=output_vocab_size,
            axial_pos_shape=axial_pos_shape,
            d_axial_pos_embs=d_axial_pos_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        EncoderBlock(d_model,
                     d_ff,
                     n_heads,
                     tl.SelfAttention,
                     dropout,
                     ff_activation,
                     ff_dropout,
                     mode=mode,
                     ff_use_sru=ff_use_sru,
                     ff_chunk_size=ff_chunk_size,
                     ff_sparsity=ff_sparsity) for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial([
        in_encoder,
        tl.Dup(),
        tl.ReversibleSerial(encoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
        tl.LayerNorm(),
    ])
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    encoder_decoder_blocks = [
        EncoderDecoderBlock(d_model,
                            d_ff,
                            n_heads,
                            dropout,
                            ff_activation,
                            ff_dropout,
                            mode,
                            ff_use_sru=ff_use_sru,
                            ff_chunk_size=ff_chunk_size,
                            ff_sparsity=ff_sparsity)
        for _ in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d
        tl.Branch([], [
            tl.PaddingMask(),
            tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1)
        ]),
        #                                     # tok_e mask  tok_d .....

        # Encode.
        encoder,  # vec_e  mask tok_d .....

        # Decode.
        tl.Select([2, 0, 1]),  # tok_d vec_e mask .....
        tl.ShiftRight(mode=mode),  # tok_d vec_e mask .....
        out_encoder,  # vec_d vec_e mask .....
        tl.Dup(),  # vec_d1 vec_d2 vec_e mask .....
        tl.ReversibleSerial(encoder_decoder_blocks),
        tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),  # vec_d vec_e mask .....
        tl.LayerNorm(),  # vec_d vec_e mask .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d .....
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
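
The `tl.Fn('XYAvg', ...)` layers above merge the two streams coming out of the reversible blocks. `tl.Fn` infers the number of inputs from the lambda's arity, so this layer consumes two stack items and pushes their elementwise average; in isolation:

import numpy as np
from trax import layers as tl

xy_avg = tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)
avg = xy_avg((np.array([2.0]), np.array([4.0])))   # -> [3.0]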
Example #10
def _FunnelBlock(d_model, d_ff, n_heads,
                 dropout, dropout_shared_axes, mode, ff_activation,
                 pool_layer, pool_size, strides, separate_cls):
  """Internal funnel block. Returns a list of layers implementing it.

  The input is an activation tensor.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling;
        should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
        neighboring windows along each axis. If specified, must be a tuple of
        the same length as `pool_size`. If None, then offsets of 1 along each
        window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to
          embeddings of the first token (`cls` from BERT paper).
  Returns:
      A list of layers that maps (activations, mask) to (activations', mask).
  """
  pooling = PoolLayer(pool_layer, pool_size, strides, separate_cls)
  mask_pooling = MaskPool(pool_size, strides, separate_cls)

  attention = tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout,
                              mode=mode)
  hidden_dropout = tl.Dropout(
      rate=dropout, shared_axes=dropout_shared_axes, mode=mode)

  feed_forward = _FeedForwardBlock(
      d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)

  return [                                          # h, mask
      tl.LayerNorm(),                               # h, mask
      tl.Branch(pooling, None),                     # h', h, mask
      tl.Residual(
          tl.Select([0, 1, 1, 2]),                  # h', h, h, mask
          attention,                                # attn, mask
          tl.Parallel(None, mask_pooling),          # attn, mask'
          hidden_dropout                            # attn, mask'
      ),                                            # funnel_activations, mask'
      tl.Residual(
          tl.LayerNorm(),
          feed_forward,
          hidden_dropout,
      )
  ]
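
The downsampling in this block is ordinary sequence pooling (`PoolLayer`/`MaskPool` above presumably wrap it with the `separate_cls` handling); with `pool_size=(2,)` and `strides=(2,)` the sequence axis is halved. The underlying trax layer in isolation:

import numpy as np
from trax import layers as tl

pool = tl.AvgPool(pool_size=(2,), strides=(2,))
x = np.arange(8, dtype=np.float32).reshape(1, 4, 2)   # (batch, seq_len, d_model)
pooled = pool(x)                                       # shape (1, 2, 2): sequence length halved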
Example #11
def FunnelTransformerEncoder(vocab_size,
                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             encoder_segment_lengths=(2, 2, 2),
                             n_heads=8,
                             max_len=2048,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             mode='train',
                             ff_activation=tl.Relu,
                             pool_layer=tl.AvgPool,
                             pool_size=(2,),
                             strides=(2,),
                             separate_cls=True):
  """Returns a Funnel Encoder.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 2 tensor representing a batch of log-probability
      distributions over N categories; shape is (batch_size, `n_classes`).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
        classification.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    encoder_segment_lengths: Tuple, where each element denotes the number of
        transformer encoder blocks preceding a funnel transformer block.
        There is no funnel block after the last sequence of encoder blocks,
        therefore the total number of blocks in the model is equal to
        `sum(encoder_segment_lengths) + len(encoder_segment_lengths) - 1`.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.
    pool_layer: Type of pooling layer used for downsampling in each of the
        funnel blocks; should be `tl.AvgPool` or `tl.MaxPool`.
    pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then `pool_size`
        must be a tuple of length :math:`n-2`.
    strides: Offsets from the location of one window to the locations of
        neighboring windows along each axis. If specified, must be a tuple of
        the same length as `pool_size`. If None, then offsets of 1 along each
        window axis, :math:`(1, ..., 1)`, will be used.
    separate_cls: If `True`, pooling in funnel blocks is not applied to the
        embedding of the first token (`cls` from the BERT paper), and only the
        final embedding of this token is used for categorization; the rest are
        discarded. If `False`, the first token is pooled like every other token,
        and all embeddings are averaged and mapped to output categories as in
        the original `TransformerEncoder` model.
  Returns:
    A Transformer model that maps strings (conveyed via token IDs) to
    probability-like activations over a range of output classes.
  """
  assert encoder_segment_lengths

  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  encoder_blocks = []
  n_encoder_segments = len(encoder_segment_lengths)

  for i in range(n_encoder_segments):
    # Building i'th segment
    for _ in range(encoder_segment_lengths[i]):
      # Create segment_size encoder blocks
      encoder_blocks.append(
          _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation))

    # If not last segment, add funnel block
    if i != n_encoder_segments - 1:
      encoder_blocks.append(
          _FunnelBlock(d_model, d_ff, n_heads, dropout,
                       dropout_shared_axes, mode,
                       ff_activation, pool_layer, pool_size,
                       strides, separate_cls))

  cls_pooling = SelectFirst() if separate_cls else tl.Mean(axis=1)

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Map to output categories.
      cls_pooling,                                # cls
      tl.Dense(n_classes),                        # cls
  )
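
The final `cls_pooling` step either keeps only the first-token embedding (`SelectFirst`, assumed to be a small helper in the same module) or averages over the sequence axis; the averaging case is just `tl.Mean(axis=1)`:

import numpy as np
from trax import layers as tl

mean_pool = tl.Mean(axis=1)
x = np.ones((2, 5, 8), dtype=np.float32)   # (batch, seq_len, d_model)
pooled = mean_pool(x)                      # shape (2, 8): one vector per example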
Example #12
  def __init__(self, model, loss_fn, optimizer, lr_schedule, inputs,
               output_dir=None, random_seed=None, n_devices=None,
               checkpoints_at=None, should_save_checkpoints=True,
               should_write_summaries=True, has_weights=False,
               nontrainable_param_map=None, mask_id=None, metrics=None):

    self._is_chief, self._n_devices, rng = (
        self._init_host_and_devices(n_devices, random_seed))
    self._should_save_checkpoints = should_save_checkpoints and self._is_chief
    self._checkpoints_at = checkpoints_at or []
    self._should_write_summaries = should_write_summaries

    self._has_weights = has_weights
    self._mask_id = mask_id
    self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
    loss_fn = loss_fn(has_weights=has_weights, mask_id=mask_id)
    inputs = inputs(self._n_devices)
    self._inputs = inputs

    # Initialize the learning rate to a dummy value. It will be set in reset().
    opt = optimizer(learning_rate=0.0)

    # Setup the model.
    model_train = model(mode='train')
    model_predict_eval = model(mode='eval')

    # Setup state.
    rng, init_rng = jax_random.split(rng)
    self._rngs = np.stack(jax_random.split(rng, self._n_devices))
    first_shape = inputs.input_shape[0]
    # If the inputs are a tuple/list, add [None] (batch) to each element.
    if isinstance(first_shape, (list, tuple)):
      model_input_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.input_shape)
      model_target_shape = tuple(
          tuple([None] + list(shape)) for shape in inputs.target_shape)
    else:  # Otherwise just add [None] to the input shape.
      model_input_shape = tuple([None] + list(inputs.input_shape))
      model_target_shape = tuple([None] + list(inputs.target_shape))
    # Change all None to 1 in input and target shape.
    model_input_shape = backend.nested_map(lambda x: x or 1, model_input_shape)
    model_target_shape = backend.nested_map(lambda x: x or 1,
                                            model_target_shape)

    def new_opt_state_and_model_state(input_shape, input_dtype, target_shape,
                                      target_dtype, rng):
      """Returns optimizer and model states suitable for training a model."""
      # Combine inputs and targets on the stack.
      if not isinstance(input_dtype, (list, tuple)):
        input_dtype = [input_dtype]
        input_shape = [input_shape]
      if not isinstance(target_dtype, (list, tuple)):
        target_dtype = [target_dtype]
        target_shape = [target_shape]
      dtypes = list(input_dtype) + list(target_dtype)
      shapes = list(input_shape) + list(target_shape)
      if self._has_weights:
        shapes += list(target_shape)
        dtypes += [np.float32 for _ in target_dtype]
      input_signature = tuple(ShapeDtype(s, d)
                              for (s, d) in zip(shapes, dtypes))
      # We need to create a new model instance and not reuse `model_train` here,
      # because `m.init` puts cached parameter values in `m` and hence the
      # next call of `m.init` will give wrong results.
      m = tl.Serial(model(mode='train'), loss_fn)
      m._set_rng_recursive(rng)  # pylint: disable=protected-access
      weights, state = m.init(input_signature)
      (slots, opt_params) = opt.tree_init(weights)
      return (OptState(weights, slots, opt_params), state)

    if _is_jit_init():
      # JIT parameter initialization to avoid memory fragmentation
      new_opt_state_and_model_state = backend.jit(new_opt_state_and_model_state,
                                                  static_argnums=(0, 1, 2, 3))
    self._new_opt_state_and_model_state = (
        lambda: new_opt_state_and_model_state(  # pylint: disable=g-long-lambda
            model_input_shape, self._inputs.input_dtype,
            model_target_shape, self._inputs.target_dtype, init_rng))

    # Arrange and initialize metrics layers.
    self._metrics = list(sorted(self._metrics_dict.keys()))
    metrics_layers = [self._metrics_dict[m](has_weights=self._has_weights,
                                            mask_id=self._mask_id)
                      for m in self._metrics]
    metrics_in_parallel = tl.Branch(*metrics_layers)
    # TODO(lukaszkaiser): clean this up once layer API stabilizes.
    # For now, we need to initialize metric layers somehow, so here we go.
    # We assume that they do not have any parameters, so this is a dummy.
    dummy_shapes = ((1, 2), (1,), (1,)) if self._has_weights else ((1, 2), (1,))
    dummy_signature = tuple(ShapeDtype(s) for s in dummy_shapes)
    metrics_in_parallel._set_rng_recursive(init_rng)  # pylint: disable=protected-access
    m_weights, m_state = metrics_in_parallel.init(dummy_signature)
    self._metrics_weights = self._for_n_devices(m_weights)
    self._metrics_state = self._for_n_devices(m_state)

    # Jit model_predict and update so they're fast.
    self._jit_eval = _jit_predict_fn(
        model_predict_eval, metrics_in_parallel, self._n_devices)
    self._jit_update_fn = _jit_update_fn(
        model_train, loss_fn, opt, self._n_devices)

    self._model_train = model_train
    self._model_predict_eval = model_predict_eval
    self._loss_fn = loss_fn
    # TODO(pkozakowski): "Learning rate schedules" are currently able to
    # control all optimizer parameters and model state, so let's rename them
    # accordingly.
    self._lr_schedule = lr_schedule

    if nontrainable_param_map is None:
      nontrainable_param_map = {}
    self._nontrainable_param_map = nontrainable_param_map

    # Those fields will be set in reset().
    self._output_dir = None
    self._train_sw = None
    self._eval_sw = None
    self._history = None
    self._lr_fn = None
    self._opt_state = None
    self._step = None
    self._model_state = None

    if output_dir is not None:
      self.reset(output_dir)
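
For reference, the `nested_map(lambda x: x or 1, ...)` calls above simply replace every `None` (the unspecified batch dimension) in a possibly nested shape tuple with 1. A plain-Python equivalent, shown only for illustration:

def none_to_one(shape):
  # Recursively replace None (and 0) with 1 in a possibly nested shape tuple.
  if isinstance(shape, (list, tuple)):
    return tuple(none_to_one(s) for s in shape)
  return shape or 1

print(none_to_one(((None, 28, 28), (None,))))   # ((1, 28, 28), (1,))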
Example #13
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                max_len=2048,
                dropout=0.1,
                dropout_shared_axes=None,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if None, then input and target integers (token IDs) are assumed to come
        from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder
        block will include dropout; else, it will pass all values through
        unaltered.
    ff_activation: Type of activation function at the end of each
        encoder/decoder block; must be an activation-type subclass of `Layer`.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
    def Embedder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(vocab_size, d_model),
            tl.Dropout(rate=dropout,
                       shared_axes=dropout_shared_axes,
                       mode=mode),
        ]

    in_embedder = Embedder(input_vocab_size)
    out_embedder = (in_embedder if output_vocab_size is None else
                    Embedder(output_vocab_size))

    # Positional encodings are not shared between encoder and decoder.
    # Since encoder doesn't run stepwise, we do not use predict mode there.
    encoder_mode = 'eval' if mode == 'predict' else mode
    in_encoder = in_embedder + [
        tl.PositionalEncoding(max_len=max_len, mode=encoder_mode)
    ]
    out_encoder = out_embedder + [
        tl.PositionalEncoding(max_len=max_len, mode=mode)
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for i in range(n_encoder_layers)
    ]

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    encoder_decoder_blocks = [
        _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)
        for i in range(n_decoder_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks ..... .....
        encoder,  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),  # tok_d ..... ..... .....
        out_encoder,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
    )
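
A hedged end-to-end shape check for this constructor (it assumes the module-private `_EncoderBlock`/`_EncoderDecoderBlock` helpers are in scope, e.g. via `trax.models.Transformer`):

import numpy as np
from trax import shapes

model = Transformer(input_vocab_size=300, d_model=32, d_ff=64,
                    n_encoder_layers=1, n_decoder_layers=1, n_heads=2,
                    max_len=32, mode='eval')
src = np.ones((1, 10), dtype=np.int32)
tgt = np.ones((1, 10), dtype=np.int32)
model.init(shapes.signature((src, tgt)))
logits, tgt_copy = model((src, tgt))    # logits shape (1, 10, 300); the target tokens pass through for the loss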
Example #14
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       max_len=2048,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       mode='train',
                       ff_activation=tl.Relu):
    """Returns a Transformer encoder merged with an N-way categorization head.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 2 tensor representing a batch of log-probability
      distributions over N categories; shape is (batch_size, `n_classes`).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
        classification.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
        block.
    n_layers: Number of encoder blocks. Each block includes attention, dropout,
        residual, feed-forward (`Dense`), and activation layers.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of `Layer`.

  Returns:
    A Transformer model that maps strings (conveyed via token IDs) to
    probability-like activations over a range of output classes.
  """
    positional_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len)
    ]

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                      mode, ff_activation) for i in range(n_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(  # toks
        # Encode.
        tl.Branch(positional_encoder, tl.PaddingMask()),  # vecs masks
        encoder_blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Map to output categories.
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # vecs
    )
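# A minimal usage sketch for the encoder above, assuming the trax package is
# installed and `_EncoderBlock` resolves as in this module; the tiny
# hyperparameters and token values are illustrative only.
import numpy as np
from trax import shapes

tiny_encoder = TransformerEncoder(vocab_size=32, n_classes=4, d_model=16,
                                  d_ff=32, n_layers=1, n_heads=2, max_len=64,
                                  mode='eval')
tokens = np.full((2, 10), 3, dtype=np.int32)   # batch of 2, no 0-padding
tiny_encoder.init(shapes.signature(tokens))
activations = tiny_encoder(tokens)             # shape (2, 4), one row per example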
def ConfigurableTransformerEncoder(vocab_size,
                                   n_classes=10,
                                   d_model=512,
                                   d_ff=2048,
                                   n_layers=6,
                                   n_heads=8,
                                   max_len=2048,
                                   dropout=0.1,
                                   dropout_shared_axes=None,
                                   mode='train',
                                   ff_activation=tl.Relu,
                                   ff_dropout=0.1,
                                   ff_chunk_size=0,
                                   ff_use_sru=0,
                                   ff_sparsity=0,
                                   ff_sparsity_type='1inN',
                                   attention_chunk_size=0,
                                   attention_type=tl.Attention,
                                   axial_pos_shape=None,
                                   d_axial_pos_embs=None):
  """Returns a Transformer encoder merged with an N-way categorization head.

  This model performs text categorization:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 2 tensor representing a batch of raw (non-normalized)
      activations over N categories; shape is (batch_size, `n_classes`).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor should
      be an integer in `range(vocab_size)`. These integers typically represent
      token IDs from a vocabulary-based tokenizer.
    n_classes: Final dimension of the output tensors, representing N-way
      classification.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      block.
    n_layers: Number of encoder blocks. Each block includes attention, dropout,
      residual, feed-forward (`Dense`), and activation layers.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'train'`, each encoder block will include dropout; else, it will
      pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder block;
      must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
      in addition to the feed-forward block (second int specifies sru size)
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string; if ff_sparsity > 0, use SparseFF when
      ff_sparsity_type=`'1inN'` and BlockSparseFF when
      ff_sparsity_type=`'Block'`.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    attention_type: The attention layer to use for the encoder part.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model that maps strings (conveyed via token IDs) to
    raw (non-normalized) activations over a range of output classes.
  """
  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      PositionalEncoder(
          mode, dropout, max_len, axial_pos_shape, d_axial_pos_embs)
  ]

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                   ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
                   ff_sparsity, ff_sparsity_type,
                   attention_chunk_size, attention_type)
      for i in range(n_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Map to output categories.
      tl.Mean(axis=1),                            # vecs
      tl.Dense(n_classes),                        # vecs
  )
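# A small sketch of the axial-position-encoding constraint documented above:
# `axial_pos_shape` lays the (padded) sequence out on a grid, and the per-axis
# depths in `d_axial_pos_embs` must sum to `d_model`. The numbers below are
# illustrative only.
d_model_example = 512
axial_pos_shape_example = (32, 64)      # 32 * 64 = 2048 positions
d_axial_pos_embs_example = (128, 384)   # per-axis embedding depths
assert sum(d_axial_pos_embs_example) == d_model_example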
Example #16
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    # TODO(kitaev): dropout=0.0 for tl.PositionalEncoding matches trax
    # Transformer, but may not be the right option in general.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=0.0, mode=mode)
    return [
        tl.Embedding(d_model, vocab_size),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
        positional_encoding,
    ]

  in_encoder = PositionalEncoder(input_vocab_size)
  out_encoder = (in_encoder if output_vocab_size is None
                 else PositionalEncoder(output_vocab_size))
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size

  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_encoder_layers)]

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch(
          in_encoder, [tl.PaddingMask(),
                       tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1)]
          ),                                # vec_e  mask  tok_d .....
      tl.Dup(),                             # vec_e1 vec_e2 mask tok_d .....
      tl.ReversibleSerial(encoder_blocks),  # vec_e1 vec_e2 mask tok_d .....
      # The two sets of activations need to be reduced to one, in this case by
      # averaging them. Note that ReformerLM concatenates instead. Various
      # options (concat, average, add, keep only one, etc.) seem to perform
      # similarly. We don't concatenate here because we want exact parameter
      # parity with the standard Transformer.
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_e  mask tok_d .....
      tl.LayerNorm(),                       # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(),                      # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
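# A toy sketch of the averaging step applied after tl.ReversibleSerial above:
# the two duplicated activation streams are reduced back to a single one.
# The named-Fn form is used here; the values are arbitrary.
import numpy as np
from trax import layers as tl

xy_avg = tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0)
x, y = np.array([2.0, 4.0]), np.array([6.0, 8.0])
print(xy_avg((x, y)))   # [4. 6.]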
def ConfigurableTransformer(input_vocab_size,
                            output_vocab_size=None,
                            d_model=512,
                            d_ff=2048,
                            n_encoder_layers=6,
                            n_decoder_layers=6,
                            n_heads=8,
                            max_len=2048,
                            dropout=0.1,
                            dropout_shared_axes=None,
                            mode='train',
                            ff_activation=tl.Relu,
                            ff_dropout=0.1,
                            ff_chunk_size=0,
                            ff_use_sru=0,
                            ff_sparsity=0,
                            ff_sparsity_type='1inN',
                            attention_chunk_size=0,
                            encoder_attention_type=tl.Attention,
                            encoder_decoder_attention_type=tl.CausalAttention,
                            axial_pos_shape=None,
                            d_axial_pos_embs=None):
  """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(output_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of raw (non-normalized)
      activations for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `output_vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
      should be an integer in `range(vocab_size)`. These integers typically
      represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
      if None, then input and target integers (token IDs) are assumed to come
      from the same vocabulary.
    d_model: Final dimension of tensors at most points in the model, including
      the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each encoder
      and decoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value when
      applying dropout within an encoder/decoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask. Sharing
      along batch and sequence axes (`dropout_shared_axes=(0,1)`) is a useful
      way to save memory and apply consistent masks to activation vectors at
      different sequence positions.
    mode: If `'predict'`, use fast inference. If `'train'`, each encoder/decoder
      block will include dropout; else, it will pass all values through
      unaltered.
    ff_activation: Type of activation function at the end of each
      encoder/decoder block; must be an activation-type subclass of `Layer`.
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers
      in addition to the feed-forward block (second int specifies sru size)
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string; if ff_sparsity > 0, use SparseFF when
      ff_sparsity_type=`'1inN'` and BlockSparseFF when
      ff_sparsity_type=`'Block'`.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    encoder_decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
  in_encoder, out_encoder, output_vocab_size = (
      EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          dropout_shared_axes,
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs)
  )

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                   ff_activation, ff_dropout, ff_chunk_size, ff_use_sru,
                   ff_sparsity, ff_sparsity_type,
                   attention_chunk_size, encoder_attention_type)
      for i in range(n_encoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # pylint: disable=g-complex-comprehension
  encoder_decoder_blocks = [
      EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                          mode, ff_activation, ff_dropout, ff_chunk_size,
                          ff_use_sru, ff_sparsity, ff_sparsity_type,
                          attention_chunk_size, encoder_decoder_attention_type)
      for i in range(n_decoder_layers)
  ]
  # pylint: enable=g-complex-comprehension

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),               # tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),    # tok_e masks ..... .....
      encoder,                            # vec_e ..... ..... .....

      # Decode.
      tl.Select([2, 1, 0]),               # tok_d masks vec_e .....
      tl.ShiftRight(mode=mode),           # tok_d ..... ..... .....
      out_encoder,                        # vec_d ..... ..... .....
      tl.Branch(
          [], tl.EncoderDecoderMask()),   # vec_d masks ..... .....
      encoder_decoder_blocks,             # vec_d masks ..... .....
      tl.LayerNorm(),                     # vec_d ..... ..... .....

      # Map to output vocab.
      tl.Select([0], n_in=3),             # vec_d tok_d
      tl.Dense(output_vocab_size),        # vec_d .....
  )
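# A minimal usage sketch for the encoder-decoder above, assuming the trax
# package is installed and EmbeddingAndPositionalEncodings / EncoderBlock /
# EncoderDecoderBlock resolve as in this module; sizes are illustrative. The
# model takes (source tokens, target tokens) and returns per-position
# activations plus the copy of the target tokens kept for the loss.
import numpy as np
from trax import shapes

tiny_model = ConfigurableTransformer(input_vocab_size=32, d_model=16, d_ff=32,
                                     n_encoder_layers=1, n_decoder_layers=1,
                                     n_heads=2, max_len=64, mode='eval')
src = np.full((2, 8), 5, dtype=np.int32)
tgt = np.full((2, 8), 7, dtype=np.int32)
tiny_model.init(shapes.signature((src, tgt)))
activations, tgt_copy = tiny_model((src, tgt))   # activations: shape (2, 8, 32)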
Example #18
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
    """Returns an RNN language model.

  This model performs autoregressive language modeling:

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `vocab_size`).

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Embedding depth throughout the model.
    n_layers: Number of RNN layers.
    rnn_cell: Type of RNN cell; must be a subclass of `Layer`.
    rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell
        state.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout.
    mode: If `'predict'`, use fast inference; if `'train'` apply dropout.

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """

    if n_layers != 2:  # TODO(jonni): Remove n_layers arg, if it can't vary?
        raise ValueError(f'Number of layers must be set to 2; instead got'
                         f' {n_layers}.')

    def MultiRNNCell():
        """Multi-layer RNN cell."""
        return tl.Serial(
            tl.Parallel([], tl.Split(n_items=n_layers)),
            tl.SerialWithSideOutputs(
                [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
            tl.Parallel([], tl.Concatenate(n_items=n_layers)))

    zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
        depth_multiplier=n_layers * rnn_cell_d_state_multiplier)

    return tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, mode=mode),
        tl.Branch([], zero_state),
        tl.Scan(MultiRNNCell(), axis=1, mode=mode),
        tl.Select([0], n_in=2),  # Drop RNN state.
        tl.Dense(vocab_size),
    )
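# A minimal usage sketch for the RNN language model above, assuming the trax
# package is installed; the tiny sizes are illustrative. Note that this
# variant ends with a plain Dense projection, so the output is unnormalized.
import numpy as np
from trax import shapes

tiny_lm = RNNLM(vocab_size=32, d_model=8, n_layers=2, mode='eval')
tokens = np.full((1, 6), 4, dtype=np.int32)
tiny_lm.init(shapes.signature(tokens))
logits = tiny_lm(tokens)   # shape (1, 6, 32)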
Example #19
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       dropout_shared_axes=None,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
  """Returns a Transformer-style encoder.

  For each item in a batch, this model performs a sequence-to-sequence mapping:

    - input: sequence of integers, usually token id's from a fixed-size
      vocabulary -- integers in `range(M)`, where `M` is the vocabulary
      size.

    - output:  same-length sequence of N-dimensional vectors, where each vector
      can be interpreted as a log-probability distribution over N discrete
      categories.

  Args:
    vocab_size: "Vocabulary size" -- input integer id's must be in
        `range(vocab_size)`. Id's typically come from preprocessing text data
        with a vocabulary-based tokenizer.
    n_classes: Size/depth of the output vectors, intended for an N-way
        classification task.
    d_model: The basic embedding size (vector depth) of the model. This is the
        vector size used by the initial embedding layer and at many intermediate
        points in the model.
    d_ff: Vector depth (typically greater than `d_model`) used in the
        feed-forward (`Dense`) layer of each encoder block.
    n_layers: Number of encoder blocks. Each encoder block includes attention,
        dropout, residual, feed-forward (`Dense`), and activation layers.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within an encoder block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
    max_len: Maximum symbol length for positional encoding.
    mode: If `'train'`, each encoder block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: The activation function (layer) at the end of each encoder
        block.

  Returns:
    A Transformer model as a layer that maps from token id's to activations
    over a set of output classes.
  """
  positional_encoder = [
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
      tl.PositionalEncoding(max_len=max_len)]

  encoder_blocks = [
      _EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                    mode, ff_activation)
      for i in range(n_layers)]

  # Assemble and return the model.
  return tl.Serial(                               # toks
      # Encode.
      tl.Branch(
          positional_encoder, tl.PaddingMask()),  # vecs masks
      encoder_blocks,                             # vecs masks
      tl.Select([0], n_in=2),                     # vecs
      tl.LayerNorm(),                             # vecs

      # Map to output categories.
      tl.Mean(axis=1),                            # vecs
      tl.Dense(n_classes),                        # vecs
      tl.LogSoftmax(),                            # vecs
  )
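# A quick sketch checking the log-probability output of the classifier above
# (this variant ends with tl.LogSoftmax); it assumes the trax package and the
# module-level _EncoderBlock, with illustrative sizes.
import numpy as np
from trax import shapes

clf = TransformerEncoder(vocab_size=16, n_classes=3, d_model=8, d_ff=16,
                         n_layers=1, n_heads=2, mode='eval')
tokens = np.full((2, 5), 2, dtype=np.int32)
clf.init(shapes.signature(tokens))
log_probs = clf(tokens)                    # shape (2, 3)
print(np.exp(log_probs).sum(axis=-1))      # approximately [1. 1.]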
Example #20
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=D_MODEL,
                d_ff=D_FF,
                n_encoder_layers=N_LAYERS,
                n_decoder_layers=N_LAYERS,
                n_heads=N_HEADS,
                max_len=MAX_SEQUENCE_LENGTH,
                dropout=DROPOUT_RATE,
                dropout_shared_axes=DROPOUT_SHARED_AXES,
                mode=MODE,
                ff_activation=FF_ACTIVATION_TYPE):
    """Returns a full Transformer model.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: Array representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length),
          where sequence_length <= ``max_len``. Array elements are integers in
          ``range(input_vocab_size)``, and 0 values mark padding positions.

        - target: Array representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length),
          where sequence_length <= ``max_len``. Array elements are integers in
          ``range(output_vocab_size)``, and 0 values mark padding positions.

    - output: 3-D array of raw activations with last/innermost dimension of
      ``output_vocab_size``, suitable for decoding into a batch of token
      strings; shape is (batch_size, sequence_length, ``output_vocab_size``).

  An example use would be to translate (tokenized) sentences from English to
  German.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in ``range(vocab_size)``. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    output_vocab_size: If specified, gives the vocabulary size for the targets;
        if ``None``, then input and target integers (token IDs) are assumed to
        come from the same vocabulary.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    n_encoder_layers: Number of encoder blocks.
    n_decoder_layers: Number of decoder blocks.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder/decoder blocks. The same rate is
        also used for attention dropout in encoder/decoder blocks.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``)
        is a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If ``'predict'``, use fast inference. If ``'train'``, each
        encoder/decoder block will include dropout; else, it will pass all
        values through unaltered.
    ff_activation: Type of activation function at the end of each
        encoder/decoder block; must be an activation-type subclass of
        :py:class:`Layer`.

  Returns:
    A Transformer model as a layer that maps from a source-target tokenized
    text pair to activations over a vocab set.
  """
    # Avoid 'predict' mode in encoder, since encoder doesn't run stepwise.
    encoder_mode = 'eval' if mode == 'predict' else mode

    # Share embedding weights if no separate output vocab size.
    in_embedder = tl.Embedding(input_vocab_size, d_model)
    if output_vocab_size is None:
        out_embedder = in_embedder
        output_vocab_size = input_vocab_size
    else:
        out_embedder = tl.Embedding(output_vocab_size, d_model)

    def _Dropout():
        return tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    def _EncBlock():
        return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)

    def _Encoder():
        encoder = tl.Serial(
            in_embedder,
            _Dropout(),
            tl.PositionalEncoding(max_len=max_len, mode=encoder_mode),
            [_EncBlock() for _ in range(n_encoder_layers)],
            tl.LayerNorm(),
        )
        return tl.Cache(encoder) if mode == 'predict' else encoder

    def _EncDecBlock():
        return _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout,
                                    dropout_shared_axes, mode, ff_activation)

    # Input to model is encoder-side tokens and decoder-side tokens: tok_e, tok_d
    # Model output is decoder-side vectors and decoder-side tokens: vec_d  tok_d
    return tl.Serial(
        tl.Select([0, 1, 1]),  # Copies decoder tokens for use in loss.

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e masks tok_d tok_d
        _Encoder(),

        # Decode.
        tl.Select([2, 1, 0]),  # Re-orders inputs: tok_d masks vec_e .....
        tl.ShiftRight(mode=mode),
        out_embedder,
        _Dropout(),
        tl.PositionalEncoding(max_len=max_len, mode=mode),
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        [_EncDecBlock() for _ in range(n_decoder_layers)],
        tl.LayerNorm(),
        tl.Select([0], n_in=3),  # Drops masks and encoding vectors.

        # Map vectors to match output vocab size.
        tl.Dense(output_vocab_size),
    )
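# A toy sketch of the embedding-sharing choice above: when output_vocab_size
# is None the decoder reuses the encoder's tl.Embedding instance, so both
# sides share one weight matrix (trax shares weights by sharing layer objects).
from trax import layers as tl

shared_embedder = tl.Embedding(100, 16)
decoder_embedder = shared_embedder      # same object => shared weights
assert decoder_embedder is shared_embedder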
Example #21
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
                 output_dir=None,
                 random_seed=None,
                 n_devices=None,
                 checkpoints_at=None,
                 should_save_checkpoints=True,
                 should_write_summaries=True,
                 metrics=None,
                 checkpoint_highest=None,
                 checkpoint_lowest=None):

        self._is_chief, _, self._n_devices, rng = (
            training.init_host_and_devices(n_devices, random_seed))
        self._should_save_checkpoints = should_save_checkpoints and self._is_chief
        self._checkpoints_at = checkpoints_at or []
        self._should_write_summaries = should_write_summaries
        if not output_dir:
            self._should_save_checkpoints = False
            self._should_write_summaries = False
        self._checkpoint_highest = checkpoint_highest
        self._checkpoint_lowest = checkpoint_lowest
        self._metrics_dict = metrics if metrics is not None else _DEFAULT_METRICS
        # Inputs is either an Inputs instance or a function that returns it.
        self._inputs = inputs
        if callable(inputs):  # If we pass a function, e.g., through gin, call it.
            self._inputs = inputs()
        # Initialize the learning rate to a dummy value. It will be set in reset().
        opt = optimizer(learning_rate=0.0)

        # Setup the model.
        model_train = model(mode='train')
        model_predict_eval = model(mode='eval')
        self._model_with_loss = tl.Serial(model_train, loss_fn)

        # Setup state.
        rng, init_rng = jax_random.split(rng)
        self._rngs = np.stack(jax_random.split(rng, self._n_devices))
        shapes, dtypes = self._inputs.example_shape_dtype
        input_signature = tuple(
            ShapeDtype(s, d) for (s, d) in zip(shapes, dtypes))

        def new_opt_state_and_model_state(rng):
            """Returns optimizer and model states suitable for training a model."""
            weights, state = self._model_with_loss.init(input_signature,
                                                        rng=rng)
            (slots, opt_params) = opt.tree_init(weights)
            return (OptState(weights, slots, opt_params), state)

        if fastmath.is_backend(fastmath.Backend.JAX):
            # JIT parameter initialization to avoid memory fragmentation
            new_opt_state_and_model_state = (
                fastmath.jit(new_opt_state_and_model_state))
        self._new_opt_state_and_model_state = (
            lambda: new_opt_state_and_model_state(init_rng))

        # Arrange and initialize metrics layers.
        self._metrics = list(sorted(self._metrics_dict.keys()))
        metrics_layers = [self._metrics_dict[m] for m in self._metrics]
        metrics_in_parallel = tl.Branch(*metrics_layers)
        metrics_in_parallel.rng = init_rng
        example_signature = tuple(
            ShapeDtype(s, d)
            for (s, d) in zip(*self._inputs.example_shape_dtype))
        model_predict_eval.init(example_signature)
        self._input_signature = example_signature
        output_signature = model_predict_eval.output_signature(
            example_signature)
        m_weights, m_state = metrics_in_parallel.init(output_signature)
        self._metrics_weights = self._for_n_devices(m_weights)
        self._metrics_state = self._for_n_devices(m_state)

        # Jit model_predict and update so they're fast.
        self._jit_eval = _jit_predict_fn(model_predict_eval,
                                         metrics_in_parallel, self._n_devices)
        self._jit_update_fn = _jit_update_fn(model_train, loss_fn, opt,
                                             self._n_devices)

        self._model_train = model_train
        self._model_predict_eval = model_predict_eval
        self._loss_fn = loss_fn
        self._lr_schedule = lr_schedule

        # Those fields will be set in reset().
        self._output_dir = None
        self._train_sw = None
        self._eval_sw = None
        self._history = None
        self._opt_state = None
        self._step = None
        self._model_state = None
        self.reset(output_dir)
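# A small sketch of the input-signature construction used in the constructor
# above, with hypothetical shapes and dtypes standing in for
# self._inputs.example_shape_dtype.
import numpy as np
from trax.shapes import ShapeDtype

example_shapes = ((4, 8), (4, 8))
example_dtypes = (np.int32, np.int32)
input_signature = tuple(
    ShapeDtype(s, d) for (s, d) in zip(example_shapes, example_dtypes))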
Example #22
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=D_MODEL,
                       d_ff=D_FF,
                       n_layers=N_LAYERS,
                       n_heads=N_HEADS,
                       max_len=MAX_SEQUENCE_LENGTH,
                       dropout=DROPOUT_RATE,
                       dropout_shared_axes=DROPOUT_SHARED_AXES,
                       mode=MODE,
                       ff_activation=FF_ACTIVATION_TYPE):
    """Returns a Transformer encoder suitable for N-way classification.

  This model maps tokenized text to N-way (``n_classes``) activations:

    - input: Array representing a batch of text strings via token IDs plus
      padding markers; shape is (batch_size, sequence_length), where
      sequence_length <= ``max_len``. Array elements are integers in
      ``range(vocab_size)``, and 0 values mark padding positions.

    - output: Array representing a batch of raw (non-normalized) activations
      over ``n_classes`` categories; shape is (batch_size, ``n_classes``).

  Args:
    vocab_size: Input vocabulary size -- each element of the input array
        should be an integer in ``range(vocab_size)``. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    n_classes: Last/innermost dimension of output arrays, suitable for N-way
        classification.
    d_model: Last/innermost dimension of activation arrays at most points in
        the model, including the initial embedding output.
    d_ff: Last/innermost dimension of special (typically wider)
        :py:class:`Dense` layer in the feedforward part of each encoder block.
    n_layers: Number of encoder blocks. Each block includes attention, dropout,
        residual, layer-norm, feedforward (:py:class:`Dense`), and activation
        layers.
    n_heads: Number of attention heads.
    max_len: Maximum symbol length for positional encoding.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within encoder blocks. The same rate is also
        used for attention dropout in encoder blocks.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (``dropout_shared_axes=(0,1)``)
        is a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If ``'train'``, each encoder block will include dropout; else, it
        will pass all values through unaltered.
    ff_activation: Type of activation function at the end of each encoder
        block; must be an activation-type subclass of :py:class:`Layer`.

  Returns:
    A Transformer model that maps strings (conveyed by token IDs) to
    raw (non-normalized) activations over a range of output classes.
  """
    def _Dropout():
        return tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    def _EncBlock():
        return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)

    return tl.Serial(
        tl.Branch([],
                  tl.PaddingMask()),  # Creates masks from copy of the tokens.
        tl.Embedding(vocab_size, d_model),
        _Dropout(),
        tl.PositionalEncoding(max_len=max_len),
        [_EncBlock() for _ in range(n_layers)],
        tl.Select([0], n_in=2),  # Drops the masks.
        tl.LayerNorm(),
        tl.Mean(axis=1),
        tl.Dense(n_classes),
    )
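# A toy sketch of the tl.Branch([], tl.PaddingMask()) step above: the tokens
# pass through unchanged and a boolean mask (True where a token is not
# padding) is pushed alongside them for the attention layers.
import numpy as np
from trax import layers as tl

toks = np.array([[5, 7, 0, 0]])
toks_out, mask = tl.Branch([], tl.PaddingMask())(toks)
print(np.squeeze(mask))   # [ True  True False False]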
Example #23
 def test_default_name(self):
     layer = tl.Branch(tl.Add(), DivideBy(0.5))
     self.assertIn('Branch', str(layer))
Example #24
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
  """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  in_encoder, out_encoder, output_vocab_size = (
      ct.EmbeddingAndPositionalEncodings(
          input_vocab_size,
          d_model,
          mode,
          dropout,
          dropout_shared_axes,
          max_len,
          output_vocab_size=output_vocab_size,
          axial_pos_shape=axial_pos_shape,
          d_axial_pos_embs=d_axial_pos_embs)
  )

  encoder_blocks = [
      transformer._EncoderBlock(d_model, d_ff, n_heads, dropout,  # pylint: disable=protected-access
                                dropout_shared_axes, mode, ff_activation)
      for i in range(n_encoder_layers)]

  encoder = tl.Serial(
      in_encoder,
      encoder_blocks,
      tl.LayerNorm()
  )
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = [
      transformer._DecoderBlock(d_model, d_ff, n_heads, dropout,  # pylint: disable=protected-access
                                dropout_shared_axes, mode, ff_activation)
      for i in range(n_decoder_layers)]

  # pylint: disable=protected-access
  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Simple encoder mask, doesn't contain extra dims.
      tl.Select([2, 0, 2], n_in=3),     #  tok_e vec_e tok_e tok_d tok_d
      tl.Fn('EncoderMask',              # mask_e vec_e tok_e tok_d tok_d
            lambda x: x != 0, n_out=1),

      # Decode.
      tl.Select([3, 1, 0, 2]),          #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder.
      tl.Select([1, 0]),                # vec_e svec_d mask_e tok_e tok_d
      ConcatWithPadding(mode=mode),     # vec_ed tok_e tok_d

      # Decoder blocks with causal attention
      decoder_blocks,                   # vec_ed tok_e tok_d
      tl.LayerNorm(),                   # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),                     # vec_ed tok_e tok_d tok_d
      StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),      # vec_d tok_d
      tl.LogSoftmax(),                  # vec_d tok_d
  )
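# A toy sketch of the tl.Select bookkeeping used above: Select copies and/or
# reorders items on the data stack by index, and n_in controls how many stack
# items it consumes.
import numpy as np
from trax import layers as tl

xs = (np.array([1]), np.array([2]))
print(tl.Select([0, 0, 1, 1])(xs))   # (array([1]), array([1]), array([2]), array([2]))
print(tl.Select([0], n_in=2)(xs))    # [1] -- the second item is dropped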
Example #25
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train' or 'eval'

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode),
        positional_encoding,
    ]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  in_encoder = PositionalEncoder(
      input_vocab_size, mode='eval' if mode == 'predict' else mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, tl.SelfAttention, dropout, ff_activation,
          ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(
          d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)
      for _ in range(n_decoder_layers)]

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], [tl.PaddingMask(),
                     tl.Fn('Squeeze',
                           lambda x: jnp.squeeze(x, (1, 2)), n_out=1)]),
      #                                     # tok_e mask  tok_d .....

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),    # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
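# A toy numpy sketch of the 'Squeeze' step above: tl.PaddingMask() emits a
# mask shaped for broadcasting against attention scores, and the singleton
# axes are dropped again before the reversible encoder blocks.
import numpy as np

toks = np.array([[5, 7, 0, 0]])
mask4d = (toks != 0)[:, None, None, :]   # like the tl.PaddingMask() output
mask2d = np.squeeze(mask4d, (1, 2))      # shape (1, 4)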
Example #26
def BERTPretrainingHead(n_classes):
    return tl.Branch(BERTClassifierHead(n_classes), BERTMLMHead())
Example #27
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    def PositionalEncoder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(d_model, vocab_size),
            tl.Dropout(rate=dropout, mode=mode),
            tl.PositionalEncoding(max_len=max_len),
        ]

    in_encoder = PositionalEncoder(input_vocab_size)
    out_encoder = (in_encoder if output_vocab_size is None else
                   PositionalEncoder(output_vocab_size))
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    encoder_blocks = [
        _EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode, ff_activation)
        for i in range(n_encoder_layers)
    ]

    encoder_decoder_blocks = [
        _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                             ff_activation) for i in range(n_decoder_layers)
    ]

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch(in_encoder, tl.PaddingMask()),  # vec_e masks ..... .....
        encoder_blocks,  # vec_e masks ..... .....
        tl.LayerNorm(),  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(),  # tok_d ..... ..... .....
        out_encoder,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        encoder_decoder_blocks,  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Select([0], n_in=3),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
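# A toy sketch of the tl.ShiftRight step used on the decoder side above:
# targets are shifted right by one position (a zero enters at the front), so
# the prediction at position i is conditioned only on targets before i.
import numpy as np
from trax import layers as tl

toks = np.array([[11, 12, 13]])
print(tl.ShiftRight()(toks))   # [[ 0 11 12]]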
Example #28
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 ff_dropout=0.1,
                 ff_chunk_size=0,
                 ff_use_sru=0,
                 ff_sparsity=0,
                 ff_sparsity_type='1inN',
                 attention_chunk_size=0,
                 encoder_attention_type=tl.Attention,
                 n_encoder_attention_layers=1,
                 decoder_attention_type=tl.CausalAttention,
                 n_decoder_attention_layers=2,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
    """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: Stochastic rate (probability) for dropping an activation value
      when applying dropout after the FF dense layer.
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity
    ff_sparsity_type: string; if ff_sparsity > 0, use SparseFF when
      ff_sparsity_type=`'1inN'` and BlockSparseFF when
      ff_sparsity_type=`'Block'`.
    attention_chunk_size: int, if > 0 run attention chunked at this size
    encoder_attention_type: The attention layer to use for the encoder part.
    n_encoder_attention_layers: int, within each encoder block, how many
      attention layers to have.
    decoder_attention_type: The attention layer to use for the
      encoder-decoder attention.
    n_decoder_attention_layers: int, within each decoder block, how many
      attention layers to have.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(input_vocab_size,
                                           d_model,
                                           mode,
                                           dropout,
                                           dropout_shared_axes,
                                           max_len,
                                           output_vocab_size=output_vocab_size,
                                           axial_pos_shape=axial_pos_shape,
                                           d_axial_pos_embs=d_axial_pos_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        ct.EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, encoder_attention_type,
                        n_encoder_attention_layers)
        for i in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        ct.DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, decoder_attention_type,
                        n_decoder_attention_layers)
        for i in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),  #  tok_e vec_e tok_e tok_d tok_d
        tl.Fn(
            'EncoderMask',  # mask_e vec_e tok_e tok_d tok_d
            lambda x: x != 0,
            n_out=1),

        # Decode.
        tl.Select([3, 1, 0, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),  # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Decoder blocks with causal attention
        decoder_blocks,  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
        tl.LogSoftmax(),  # vec_d tok_d
    )
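# A toy sketch of the 'EncoderMask' step above: a simple (batch, seq_len)
# boolean mask built directly from the encoder tokens, with 0 treated as
# padding.
import numpy as np
from trax import layers as tl

toks = np.array([[5, 7, 0, 0]])
print(tl.Fn('EncoderMask', lambda x: x != 0, n_out=1)(toks))
# [[ True  True False False]]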
Example #29
 def test_add_div(self):
     layer = tl.Branch(tl.Add(), DivideBy(0.5))
     xs = [np.array([1, 2, 3]), np.array([10, 20, 30])]
     ys = layer(xs)
     self.assertEqual(as_list(ys), [[11, 22, 33], [2, 4, 6]])
Example #30
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a Transformer model.

  This model expects an input pair: source, target.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a source, target pair to
    activations over a vocab set.
  """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                         ff_activation) for i in range(n_encoder_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode,
                           ff_activation) for i in range(n_decoder_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Serial(  # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
        tl.Swap(),  # toks_d toks_e ....

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Branch(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  # toks_d masks vecs_e
        tl.ShiftRight(),  # toks_d ..... ......
        out_embed,  # vecs_d ..... ......
        tl.Branch([], tl.EncoderDecoderMask()),  # vecs_d masks ......
        encoder_decoder_stack,  # vecs_d masks vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
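# A toy sketch of the tl.Parallel([], tl.Dup()) step above: the first stack
# item passes through unchanged while the second is duplicated (standing in
# for the decoder tokens copied for the loss).
import numpy as np
from trax import layers as tl

xs = (np.array([1]), np.array([2]))
print(tl.Parallel([], tl.Dup())(xs))   # (array([1]), array([2]), array([2]))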