Ejemplo n.º 1
0
def BERTClassifierHead(n_classes):
    """Builds a classification head that outputs log-softmax over classes.

    The head consumes two stack inputs, keeps only the first, projects it
    with a dense layer and applies log-softmax.

    Args:
        n_classes: Number of units of the dense projection, i.e. the number
            of classes scored by the head.

    Returns:
        A `tl.Serial` layer made of select, dense and log-softmax.
    """
    # Only the first of the two inputs is classified; the second is dropped.
    keep_first = tl.Select([0], n_in=2)
    projection = tl.Dense(
        n_classes,
        kernel_initializer=tl.RandomNormalInitializer(0.02),
        bias_initializer=tl.RandomNormalInitializer(1e-6),
    )
    return tl.Serial([keep_first, projection, tl.LogSoftmax()])
Ejemplo n.º 2
0
    def __init__(self, residual_layers):
        """Builds the forward stack and the layers kept for the reverse pass.

        Args:
            residual_layers: Layer(s) placed as the third branch of the
                `tl.Parallel` inside `compute_residual`; they produce the
                residual that is added to (forward) or subtracted from
                (reverse) the first stack element.
        """
        # Stack comments on the right show the stack contents after each step.
        self.compute_residual = tl.Serial(  # x1_or_y1, x2,           ...
            tl.Select([1, 0, 1]),  # x2, x1_or_y1, x2,       ...
            tl.Parallel([], [],
                        residual_layers),  # x2, x1_or_y1, residual, ...
            tl.Select([2, 1, 0]),  # residual, x1_or_y1, x2, ...
        )

        # Everything `compute_residual` outputs beyond the first two elements
        # must be carried along unchanged by the Add / SubtractTop stage.
        self.n_preserve = self.compute_residual.n_out - 2
        # Note: the repeated entries are the same empty-list object; they are
        # only used as placeholder branches for `tl.Parallel`
        # (presumably pass-through branches -- confirm against trax docs).
        parallel_preserve = [[]] * self.n_preserve

        layers = [
            self.compute_residual,
            tl.Parallel(tl.Add(), *parallel_preserve)
        ]
        super(ReversibleHalfResidual, self).__init__(layers)

        # Reverse pass reuses `compute_residual` but subtracts instead of adds.
        self.subtract_top = tl.Parallel(tl.SubtractTop(), *parallel_preserve)
        self.reverse_layers = [self.compute_residual, self.subtract_top]
Ejemplo n.º 3
0
def BERTRegressionHead():
    """Builds a regression head producing one value per example.

    The head consumes two stack inputs, keeps only the first, projects it to
    a single unit and squeezes away axis 1 of the result.

    Returns:
        A `tl.Serial` layer made of select, dense and an axis-removing `Fn`.
    """
    def _squeeze_axis_1(x):
        # The dense layer has a single unit; drop that length-1 axis.
        return np.squeeze(x, axis=1)

    keep_first = tl.Select([0], n_in=2)
    projection = tl.Dense(
        1,
        kernel_initializer=tl.RandomNormalInitializer(0.02),
        bias_initializer=tl.RandomNormalInitializer(1e-6),
    )
    remove_axis = tl.Fn('RemoveAxis', _squeeze_axis_1)
    return tl.Serial([keep_first, projection, remove_axis])
Ejemplo n.º 4
0
def awr_weight_stat(stat_name, stat_fn, beta, preprocess_layer):
  """Builds a layer computing a statistic of the AWR weights.

  Args:
    stat_name: Name of the statistic; its capitalized form is appended to
      'AWRWeight' to name the resulting `Fn` layer.
    stat_fn: Callable applied to the result of `awr_weights(x, beta)`.
    beta: Second argument forwarded to `awr_weights`.
    preprocess_layer: Layer run before the statistic, or None to use
      `tl.Select([1])` (just the advantages, per the original comment).

  Returns:
    A `tl.Serial` of the preprocessing layer followed by the statistic layer.
  """
  if preprocess_layer is None:
    # Select just the advantages if preprocess layer is not given.
    first_stage = tl.Select([1])
  else:
    first_stage = preprocess_layer

  def _stat_of_weights(x):
    return stat_fn(awr_weights(x, beta))

  layer_name = 'AWRWeight' + stat_name.capitalize()
  return tl.Serial([first_stage, tl.Fn(layer_name, _stat_of_weights)])
Ejemplo n.º 5
0
def ReZeroTransformerEncoder(vocab_size,
                             n_classes=10,
                             d_model=512,
                             d_ff=2048,
                             n_layers=6,
                             n_heads=8,
                             dropout=0.1,
                             dropout_shared_axes=None,
                             max_len=2048,
                             mode='train',
                             ff_activation=tl.Relu):
    """Builds a ReZero transformer encoder that classifies token sequences.

    The model takes a tensor of tokens as its input.

    Args:
      vocab_size: int: vocab size
      n_classes: how many classes on output
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of encoder blocks
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A ReZero transformer model as a layer that maps from a tensor of tokens
      to activations over a set of output classes.
    """
    def _NewBlock():
        return _EncoderBlock(d_model, d_ff, n_heads, dropout,
                             dropout_shared_axes, mode, ff_activation)

    # tokens --> vectors
    token_encoder = [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    blocks = [_NewBlock() for _ in range(n_layers)]

    # Stack comments on the right show what each step leaves on the stack.
    return tl.Serial(  # toks
        # Encode: vectors plus a padding mask derived from the same tokens.
        tl.Branch(token_encoder, tl.PaddingMask()),  # vecs masks
        blocks,  # vecs masks
        tl.Select([0], n_in=2),  # vecs
        tl.LayerNorm(),  # vecs

        # Average over axis 1, then map to output categories.
        tl.Mean(axis=1),  # vecs
        tl.Dense(n_classes),  # vecs
    )
Ejemplo n.º 6
0
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation):
    """Assembles the three `ResidualZero` sublayers of an encoder-decoder block.

    The block works on a triple (decoder_input, mask, encoder), where the
    mask is created from the original source to prevent attending to the
    padding part of the encoder.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A list of layers which maps triples (decoder_activations, mask,
      encoder_activations) to triples of the same sort.
    """
    def _NewDropout():
        # Each sublayer gets its own dropout instance.
        return tl.Dropout(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    # Causal self-attention on the decoder activations.
    self_attention_sublayer = ResidualZero(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.CausalAttention(d_model, n_heads=n_heads,
                           mode=mode),  # vec_d ..... .....
        _NewDropout(),  # vec_d ..... .....
    )

    # Attention from the decoder (queries) to the encoder (keys and values).
    enc_dec_attention_sublayer = ResidualZero(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        tl.AttentionQKV(d_model,
                        n_heads=n_heads,
                        dropout=dropout,
                        mode=mode,
                        cache_KV_in_predict=True),  # vec_d masks vec_e
        _NewDropout(),  # vec_d masks vec_e
    )

    feed_forward_sublayer = ResidualZero(
        tl.LayerNorm(),
        _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes, mode,
                          ff_activation),  # vec_d masks vec_e
        _NewDropout(),
    )

    return [  # vec_d masks vec_e
        self_attention_sublayer,
        enc_dec_attention_sublayer,
        feed_forward_sublayer,
    ]
Ejemplo n.º 7
0
def _RelativeDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                          mode, ff_activation, context_bias_layer,
                          location_bias_layer, total_pooling):
  """Returns a list of layers implementing a relative-attention decoder block.

  The block is two residual sublayers: layer norm + relative attention +
  dropout, followed by a feed-forward block. The activations are copied three
  times (`tl.Select([0, 0, 0])`) before being handed to the attention layer.

  Args:
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    d_ff: Size of special dense layer in the feed-forward part of each block.
    n_heads: Number of attention heads.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
    dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`) is
        a useful way to save memory and apply consistent masks to activation
        vectors at different sequence positions.
    mode: If `'train'`, each block will include dropout; else, it will
        pass all values through unaltered.
    ff_activation: Type of activation function at the end of each block; must
        be an activation-type subclass of `Layer`.
    context_bias_layer: Global context bias from Transformer XL's attention.
    location_bias_layer: Global location bias from Transformer XL's attention.
    total_pooling: The combined pool size of previously used funnel blocks.

  Returns:
    A list of two residual layers. Per the inline stack comments the block
    maps vecs to vecs (NOTE(review): the earlier docstring mentioned
    att_vecs and mask as well -- confirm against callers).
  """
  attention_sublayer = tl.Residual(               # vecs
      tl.LayerNorm(),
      tl.Select([0, 0, 0]),
      RelativeAttentionLMLayer(
          d_model, context_bias_layer, location_bias_layer,
          total_pooling,
          n_heads=n_heads, dropout=dropout, mode=mode),
      tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
  )                                               # vecs

  feed_forward_sublayer = tl.Residual(
      _FeedForwardBlock(
          d_model, d_ff, dropout, dropout_shared_axes, mode, ff_activation)
  )                                               # vecs

  return [attention_sublayer, feed_forward_sublayer]
Ejemplo n.º 8
0
 def _inp_layers():
     """Builds the input stage that yields (vec_e, mask_e, tok_e, tok_d).

     Reads `input_vocab_size` and `in_encoder` from the enclosing scope. With
     a vocabulary the inputs are (tok_e, tok_d); without one they are
     (vec_e, mask_e, tok_d).

     Returns:
         A `tl.AssertFunction` wrapping the input layers with a shape spec.
     """
     if input_vocab_size is None:
         # Input in this case is vec_e, mask_e, tok_d. Where all downstream
         # operations expect tok_e, we give it instead mask_e, expecting that
         # downstream ops only are looking for padding/not padding.
         vector_input_layers = tl.Serial(  # vec_e mask_e tok_d
             tl.Select([0, 1, 1, 2]),
             tl.Parallel(in_encoder, [],
                         _AsTokenIDs()))  # vec_e mask_e tok_e tok_d
         return tl.AssertFunction(
             'blf,bl,br->bld,bl,bl,br',  # f: in-feature depth, d: out-vector depth
             vector_input_layers)

     token_input_layers = tl.Serial(  # tok_e tok_d
         tl.Select([0, 0, 0, 1]),
         tl.Parallel(
             in_encoder,
             [tl.PaddingMask(), _RemoveAxes12()
              ]))  # vec_e mask_e tok_e tok_d
     return tl.AssertFunction(
         'bl,br->bld,bl,bl,br',  # b: batch, l/r: enc/dec length, d: vec depth
         token_input_layers)
Ejemplo n.º 9
0
Archivo: rse.py Proyecto: yliu45/trax
def BenesBlock(d_model, dropout, mode):
  """Builds a block that scans forward and backward steps over a bit sequence.

  The input is duplicated, one copy is turned into the sequence
  `0 .. n_bits - 1` (where `n_bits` is the number of bits needed to write
  `seq_length - 1`), and that sequence drives a `tl.Scan` over
  `_ForwardStep` followed by a `tl.Scan` over `_BackwardStep`.

  Args:
    d_model: Forwarded to `_ForwardStep` and `_BackwardStep`.
    dropout: Forwarded to `_ForwardStep` and `_BackwardStep`.
    mode: Forwarded to `_ForwardStep` and `_BackwardStep`.

  Returns:
    A `tl.Serial` layer; the final `tl.Select([1])` keeps only the second
    element left on the stack by the scans.
  """
  def bit_sequence(inputs):
    seq_length = inputs.shape[1]
    # For seq_length >= 2, (seq_length - 1).bit_length() is exactly
    # floor(log2(seq_length - 1)) + 1. Doing it in integer arithmetic avoids
    # float rounding of log ratios and the log(0) = -inf case when
    # seq_length == 1 (which now yields an empty bit sequence).
    n_bits = np.int32(max(int(seq_length) - 1, 0).bit_length())
    return jnp.arange(0, n_bits)

  return tl.Serial(
      tl.Dup(),
      tl.Fn('BitSeq', bit_sequence, n_out=1),
      tl.Scan(_ForwardStep(d_model, dropout, mode)),
      tl.Scan(_BackwardStep(d_model, dropout, mode)),
      tl.Select([1]),
  )
Ejemplo n.º 10
0
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                         ff_activation):
    """Assembles the three residual sublayers of an encoder-decoder block.

    The block works on a triple (decoder_input, mask, encoder), where the
    mask is created from the original source to prevent attending to the
    padding part of the encoder.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      layer_idx: which layer are we at (for bookkeeping)
      mode: str: 'train' or 'eval'
      ff_activation: the non-linearity in feed-forward layer

    Returns:
      A list of layers which maps triples (decoder_activations, mask,
      encoder_activations) to triples of the same sort.
    """
    def _NewDropout():
        # Each attention sublayer gets its own dropout instance.
        return tl.Dropout(rate=dropout, mode=mode)

    # Causal self-attention on the decoder activations.
    self_attention_sublayer = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.BasicCausalAttention(d_model,
                                n_heads=n_heads,
                                dropout=dropout,
                                mode=mode),  # vec_d masks .....
        _NewDropout(),  # vec_d ..... .....
    )

    # Attention from the decoder (queries) to the encoder (keys and values).
    enc_dec_attention_sublayer = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        tl.AttentionQKV(d_model,
                        n_heads=n_heads,
                        dropout=dropout,
                        mode=mode),  # vec_d masks vec_e
        _NewDropout(),  # vec_d masks vec_e
    )

    feed_forward_sublayer = tl.Residual(
        _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode,
                          ff_activation)  # vec_d masks vec_e
    )

    return [  # vec_d masks vec_e
        self_attention_sublayer,
        enc_dec_attention_sublayer,
        feed_forward_sublayer,
    ]
Ejemplo n.º 11
0
 def ConditionedBlock(current_layer_num):
   """Wraps a decoder block so it only runs for sufficiently deep settings.

   Args:
     current_layer_num: Index of this layer; the block runs when
       n_layers_to_keep (taken from the stack) is larger than it.

   Returns:
     A `tl.Serial` mapping (embedding, n_layers_to_keep) to a stack of the
     same form.
   """
   # if n_layers_to_keep > current_layer_num
   should_run = LargerThan(float(current_layer_num))
   # then: run block
   run_block = tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
       d_model, d_ff, n_heads, dropout, [], mode, ff_activation))
   # else: run noop
   skip_block = tl.Serial()

   return tl.Serial(
       # stack: embedding, n_layers_to_keep
       tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
       tl.Cond(should_run, run_block, skip_block),
       # stack: embedding, n_layers_to_keep
   )
Ejemplo n.º 12
0
 def ConditionedBlock(current_layer_num):
   """Wraps a decoder block so it can be skipped based on a stack value.

   Args:
     current_layer_num: Index of this layer; passed to `skip_mode_fun` to
       decide whether the threshold is `skip_fraction` or 0.0.

   Returns:
     A `tl.Serial` mapping (embedding, n_layers_to_keep) to a stack of the
     same form.
   """
   # if random() > skip_fraction OR layer not in skip_mode ...
   if skip_mode_fun(current_layer_num):
     threshold = skip_fraction
   else:
     threshold = 0.0
   should_run = LargerThan(threshold)
   # then: run block
   run_block = tl.Serial(transformer._DecoderBlock(  # pylint: disable=g-complex-comprehension,protected-access
       d_model, d_ff, n_heads, dropout, [], mode, ff_activation))

   return tl.Serial(
       # stack: embedding, n_layers_to_keep
       tl.Select([1, 0, 1]),  # n_layers_to_keep, embedding, n_layers_to_keep
       # else: noop (implicit, no third argument is passed to Cond)
       tl.Cond(should_run, run_block),
       # stack: embedding, n_layers_to_keep
   )
Ejemplo n.º 13
0
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  The input to the model is a tensor of tokens (ints).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding (n_units in the RNN cell)
    n_layers: int: number of RNN layers; only 2 is supported
    rnn_cell: the RNN cell
    rnn_cell_d_state_multiplier: how many times is RNN cell state larger
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train', 'eval' or 'predict', predict mode is for fast inference

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.

  Raises:
    ValueError: If `n_layers` is not 2.
  """
  # Validate eagerly with a real exception: an `assert` is stripped under
  # `python -O`, which would let an unsupported layer count build silently.
  if n_layers != 2:
    raise ValueError(
        f'Number of layers must be set to 2; instead got {n_layers}.')

  def MultiRNNCell():
    """Multi-layer RNN cell."""
    return tl.Serial(
        # Split the combined state into one piece per RNN layer.
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        # Re-join the per-layer states into a single state tensor.
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  zero_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=n_layers * rnn_cell_d_state_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      # NOTE(review): argument order of tl.Embedding differs between trax
      # versions -- confirm (d_model, vocab_size) matches the installed one.
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], zero_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
Ejemplo n.º 14
0
def _ReversibleSerialForget(layers, d_model, n_layers, forget_dense=True):
    """Chains reversible chunks of `n_layers` layers with forgetting blocks.

    Args:
        layers: List of layers to arrange.
        d_model: Width of the dense layer used in the forgetting block.
        n_layers: Chunk size; if falsy, or if at most `n_layers + 1` layers
            remain, everything goes into one `tl.ReversibleSerial`.
        forget_dense: If True the forgetting block is `_XYAvg` + dense + dup,
            otherwise it is `tl.Select([0, 1])`.

    Returns:
        A `tl.ReversibleSerial` for short inputs, otherwise a `tl.Serial` of
        the first chunk, a forgetting block and the recursively built rest.
    """
    fits_in_one_chunk = not n_layers or len(layers) <= n_layers + 1
    if fits_in_one_chunk:
        return tl.ReversibleSerial(layers)

    first_chunk = layers[:n_layers]
    remaining = layers[n_layers:]

    forgetting_layer = (
        tl.Serial(_XYAvg(), tl.Dense(d_model), tl.Dup())
        if forget_dense else tl.Select([0, 1])
    )

    rest = _ReversibleSerialForget(remaining, d_model, n_layers, forget_dense)
    return tl.Serial(tl.ReversibleSerial(first_chunk), forgetting_layer, rest)
Ejemplo n.º 15
0
 def test_call_and_grad(self):
     """Runs a Favor model forward and checks the embedding gradient shape."""
     model = tl.Serial(
         tl.Branch(tl.Embedding(3, 4), tl.PaddingMask()),
         sparsity.Favor(d_feature=4, n_heads=2),
         tl.Select([0], n_in=2),
         tl.LogSoftmax(),
         tl.CrossEntropyLoss(),
     )
     tokens = np.ones((1, 2), dtype=np.int32)
     loss_weights = np.ones_like(tokens).astype(np.float32)
     batch = (tokens, tokens, loss_weights)

     tokens_sig = shapes.signature(tokens)
     weights_sig = shapes.signature(loss_weights)
     model.init((tokens_sig, tokens_sig, weights_sig))

     # Forward pass: the loss is a scalar.
     loss = model(batch)
     self.assertEqual(loss.shape, ())

     state = model.state
     rng = fastmath.random.get_prng(0)

     def loss_fn(weights, inp):
         # pure_fn returns a tuple; element 0 is the output being differentiated.
         return model.pure_fn(inp, weights, state, rng=rng)[0]

     grads = fastmath.grad(loss_fn)(model.weights, batch)
     # Gradient entry for the (3, 4) embedding table.
     self.assertEqual(grads[0][1][0].shape, (3, 4))
Ejemplo n.º 16
0
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode):
    """Builds a reversible transformer decoder layer.

    The layer is three `ReversibleHalfResidual` sublayers (causal attention,
    encoder-decoder attention, feed-forward), each followed by a
    `tl.ReversibleSwap`.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      ff_activation: the non-linearity in feed-forward layer
      mode: str: 'train' or 'eval'

    Returns:
      the layer, as a list of sublayers.
    """
    def _NewDropout():
        # TODO(kitaev): BroadcastedDropout?
        return tl.Dropout(rate=dropout, mode=mode)

    causal_attention_layers = [
        tl.LayerNorm(),
        tl.CausalAttention(d_model, n_heads=n_heads, mode=mode),
        _NewDropout(),
    ]

    enc_dec_attention_layers = [
        [
            tl.LayerNorm(),
            tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        ],
        tl.AttentionQKV(d_model,
                        n_heads=n_heads,
                        dropout=dropout,
                        mode=mode),
        _NewDropout(),
    ]

    feed_forward_layers = FeedForward(d_model, d_ff, dropout, ff_activation,
                                      mode)

    return [  # vec_d1 vec_d2 masks vec_e
        # TODO(kitaev): consider ReversibleAttentionHalfResidual for efficiency
        ReversibleHalfResidual(causal_attention_layers),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(enc_dec_attention_layers),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward_layers),
        tl.ReversibleSwap(),  # vec_d1 vec_d2 masks vec_e
    ]
Ejemplo n.º 17
0
def RawPolicy(seq_model, n_controls, n_actions):
    """Wraps a sequence model in a policy interface.

    The resulting model takes observation and action sequences as input, but
    only uses the observations. It adds output heads for action logits and
    value predictions.

    Args:
      seq_model: Trax sequence model taking as input and outputting a sequence
        of continuous vectors.
      n_controls: Number of controls.
      n_actions: Number of action categories in each control.

    Returns:
      A model of signature (obs, act) -> (act_logits, values), with shapes:
        obs: (batch_size, length + 1, obs_depth)
        act: (batch_size, length, n_controls)
        act_logits: (batch_size, length, n_controls, n_actions)
        values: (batch_size, length)
    """
    @tl.layer()
    def SplitControls(x, **unused_kwargs):  # pylint: disable=invalid-name
        """Splits logits for actions in different controls."""
        leading_dims = x.shape[:2]
        return np.reshape(x, leading_dims + (n_controls, n_actions))

    logits_head = [
        # Predict all action logits at the same time.
        tl.Dense(n_controls * n_actions),
        # Then group them into separate controls, adding a new dimension.
        SplitControls(),  # pylint: disable=no-value-for-parameter
        # Needed because there is 1 less actions than observations.
        DropLastTimestep(),  # pylint: disable=no-value-for-parameter
        tl.LogSoftmax(),
    ]
    value_head = [tl.Dense(1), tl.Flatten()]

    policy_layers = [  # (obs, act)
        tl.Select([0], n_in=2),  # (obs,)
        seq_model,  # (obs_hidden,)
        tl.Dup(),  # (obs_hidden, obs_hidden)
        tl.Parallel(logits_head, value_head),  # (act_logits, values)
    ]
    return tl.Serial(policy_layers)
Ejemplo n.º 18
0
def PositionLookupTransformerLM(vocab_size=128,
                                d_model=256,
                                d_ff=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
    """Builds a decoder-only Transformer language model with position lookup.

    Args:
      vocab_size: int: vocab size
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_layers: int: number of layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      max_len: maximal length; the first `max_len` rows of `_POSITIONS` are
        used
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    positions = _POSITIONS[:max_len, :]

    def _NewBlock():
        return _DecoderBlock(positions, d_model, d_ff, n_heads, dropout, mode)

    blocks = [_NewBlock() for _ in range(n_layers)]

    # Tokens --> embedded vectors.
    embed = [
        tl.ShiftRight(),
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, mode=mode),
    ]
    # Final vectors --> log-probabilities over the vocabulary.
    readout = [
        tl.LayerNorm(),
        tl.Dense(vocab_size),
        tl.LogSoftmax(),
    ]

    return tl.Serial(
        embed,
        # Put positions next to the activations for the decoder blocks.
        tl.Branch([], NewPositionalEncoding(positions=positions)),
        blocks,
        tl.Select([0], n_in=2),  # Drop positions.
        readout)
Ejemplo n.º 19
0
    def test_given_n_in(self):
        """An explicit `n_in` argument is reported back by the layer."""
        for expected_n_in in (2, 3):
            select = tl.Select([0], n_in=expected_n_in)
            self.assertEqual(select.n_in, expected_n_in)
Ejemplo n.º 20
0
# 2. `3`
# 3. `tl.Select([0,1,0,1])` 
# 4. `add` 
# 5. `mul` 
# 6. `add`. 
# 
# `tl.Select` takes a list or tuple of 0-based indices that pick elements relative to the top of the stack. In our example the top of the stack is `3` (index 0), followed by `4` (index 1). `tl.Select([0, 1, 0, 1])` puts the selected elements on top of the stack in the given order, so after the command the stack holds `3` `4` `3` `4`. The steps of the calculation for our example are shown in the table below. As in the previous table, each column shows the contents of the stack and the outputs after each operation is carried out.
# 
# <div style="text-align:center" width="20px"><img src="Stack2.png" /></div>
# 
# After processing all the inputs the stack contains 25 which is the answer we get above.

# In[ ]:


# Duplicate the two inputs, then apply add, mul, add in that order.
serial = tl.Serial(
    tl.Select([0, 1, 0, 1]),
    Addition(),
    Multiplication(),
    Addition(),
)

# Initialization
x = (np.array([3]), np.array([4]))  # input
input_signature = shapes.signature(x)

serial.init(input_signature)  # initializing serial instance


# Report the model and its basic properties.
print("-- Serial Model --")
print(serial, "\n")
print("-- Properties --")
print("name :", serial.name)
print("sublayers :", serial.sublayers)
print("expected inputs :", serial.n_in)
print("promised outputs :", serial.n_out, "\n")
Ejemplo n.º 21
0
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Builds an autoregressive RNN language model.

    - input: rank 2 tensor representing a batch of text strings via token IDs
      plus padding markers; shape is (batch_size, sequence_length). The tensor
      elements are integers in `range(vocab_size)`, and `0` values mark padding
      positions.

    - output: rank 3 tensor with one vector of size `vocab_size` per sequence
      position; shape is (batch_size, sequence_length, `vocab_size`). The
      last layer is a plain `tl.Dense`, so these are activations rather than
      normalized log-probabilities.

  Args:
    vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    d_model: Embedding depth throughout the model.
    n_layers: Number of RNN layers; must be 2.
    rnn_cell: Type of RNN cell; must be a subclass of `Layer`.
    rnn_cell_d_state_multiplier: Multiplier for feature depth of RNN cell
        state.
    dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout.
    mode: If `'predict'`, use fast inference; if `'train'` apply dropout.

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.

  Raises:
    ValueError: If `n_layers` is not 2.
  """
  # TODO(jonni): Remove n_layers arg, if it can't vary?
  layer_count_supported = (n_layers == 2)
  if not layer_count_supported:
    raise ValueError(f'Number of layers must be set to 2; instead got'
                     f' {n_layers}.')

  def MultiRNNCell():
    """Multi-layer RNN cell."""
    cells = [rnn_cell(n_units=d_model) for _ in range(n_layers)]
    split_state = tl.Parallel([], tl.Split(n_items=n_layers))
    join_state = tl.Parallel([], tl.Concatenate(n_items=n_layers))
    return tl.Serial(
        split_state,
        tl.SerialWithSideOutputs(cells),
        join_state,
    )

  state_depth_multiplier = n_layers * rnn_cell_d_state_multiplier
  initial_state = tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
      depth_multiplier=state_depth_multiplier
  )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(vocab_size, d_model),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Branch([], initial_state),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Select([0], n_in=2),  # Drop RNN state.
      tl.Dense(vocab_size),
  )
Ejemplo n.º 22
0
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
  """Builds an LSTM sequence-to-sequence model with attention.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(target_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `target_vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  The model works as follows:

  * Input encoder runs on the input tokens and creates activations that
    are used as both keys and values in attention.
  * Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.
  * Attention runs on the queries, keys and values masking out input padding.
  * Decoder runs on the result and ends in a log-softmax over the target
    vocabulary; the target tokens are passed through alongside it.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(vocab_size)`. These integers typically
        represent token IDs from a vocabulary-based tokenizer.
    target_vocab_size: Target vocabulary size.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    n_encoder_layers: Number of LSTM layers in the encoder.
    n_decoder_layers: Number of LSTM layers in the decoder after attention.
    n_attention_heads: Number of attention heads.
    attention_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout within an attention block.
    mode: If `'predict'`, use fast inference. If `'train'`, each attention block
        will include dropout; else, it will pass all values through unaltered.

  Returns:
    An LSTM sequence-to-sequence model as a layer that maps from a
    source-target tokenized text pair to activations over a vocab set.
  """
  def _LSTMStack(depth):
    return [tl.LSTM(d_model) for _ in range(depth)]

  def _AttentionInputs(encoder_activations, decoder_activations, input_tokens):
    """Returns (queries, keys, values, mask) for the attention layer."""
    # Mask is 1 where inputs are not padding (0) and 0 where they are padding.
    is_token = (input_tokens != 0)
    # We need to add axes to the mask for attention heads and decoder length.
    is_token = jnp.reshape(
        is_token, (is_token.shape[0], 1, 1, is_token.shape[1]))
    # Broadcast so mask is [batch, 1 for heads, decoder-len, encoder-len].
    decoder_len = decoder_activations.shape[1]
    broadcast_mask = is_token + jnp.zeros((1, 1, decoder_len, 1))
    # Queries come from the decoder; keys and values from the encoder.
    return (decoder_activations, encoder_activations, encoder_activations,
            broadcast_mask.astype(jnp.float32))

  input_encoder = tl.Serial(
      tl.Embedding(input_vocab_size, d_model),
      _LSTMStack(n_encoder_layers),
  )

  pre_attention_decoder = tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(target_vocab_size, d_model),
      tl.LSTM(d_model),
  )

  attention = tl.Residual(
      tl.AttentionQKV(d_model, n_heads=n_attention_heads,
                      dropout=attention_dropout, mode=mode)
  )

  return tl.Serial(              # in-toks, target-toks
      tl.Select([0, 1, 0, 1]),   # in-toks, target-toks, in-toks, target-toks
      tl.Parallel(input_encoder, pre_attention_decoder),
      tl.Fn('PrepareAttentionInputs', _AttentionInputs,
            n_out=4),            # q, k, v, mask, target-toks
      attention,                 # decoder-vecs, mask, target-toks
      tl.Select([0, 2]),         # decoder-vecs, target-toks
      _LSTMStack(n_decoder_layers),
      tl.Dense(target_vocab_size),
      tl.LogSoftmax()
  )
Ejemplo n.º 23
0
def TransformerNoEncDecAttention(input_vocab_size,
                                 output_vocab_size=None,
                                 d_model=512,
                                 d_ff=2048,
                                 n_encoder_layers=6,
                                 n_decoder_layers=6,
                                 n_heads=8,
                                 dropout=0.1,
                                 dropout_shared_axes=None,
                                 max_len=2048,
                                 mode='train',
                                 ff_activation=tl.Relu):
  """Builds a Transformer whose decoder sees the encoder via concatenation.

  Instead of encoder-decoder attention, encoder and decoder vectors are
  concatenated and processed by decoder blocks with causal attention.

  Per the stack comments below, the model input is the pair
  (encoder-side tokens, decoder-side tokens).

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab, and the same
      embedding layers are reused for both.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    dropout_shared_axes: axes on which to share dropout mask
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'; with 'predict' the encoder is wrapped in
      `tl.Cache`
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a token pair to
    activations over a vocab set.
  """
  # pylint: disable=protected-access
  def _EmbedWithPositions(vocab_size):  # tokens --> vectors
    return [
        tl.Embedding(vocab_size, d_model),
        tl.Dropout(rate=dropout, shared_axes=dropout_shared_axes, mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]

  def _Blocks(make_block, depth):
    return [
        make_block(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                   mode, ff_activation)
        for _ in range(depth)
    ]

  in_encoder = _EmbedWithPositions(input_vocab_size)
  if output_vocab_size is None:
    # Shared vocabulary: reuse the very same embedding layers for the target.
    out_encoder = in_encoder
    output_vocab_size = input_vocab_size
  else:
    out_encoder = _EmbedWithPositions(output_vocab_size)

  encoder = tl.Serial(
      in_encoder,
      _Blocks(transformer._EncoderBlock, n_encoder_layers),
      tl.LayerNorm()
  )
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  decoder_blocks = _Blocks(transformer._DecoderBlock, n_decoder_layers)

  # Layers that tell encoder/decoder positions apart around the concatenation.
  encoder_side_mask = transformer._MaskOfRightShiftedArray(n_positions=0)
  decoder_side_mask = transformer._MaskOfRightShiftedArray()
  concat = transformer._ConcatWithPadding()
  strip_encoder_part = transformer._StripFromConcatenateWithPadding()

  # Assemble and return the model.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d

      # Encode.
      tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Simple encoder mask, doesn't contain extra dims.
      tl.Select([2, 0, 2], n_in=3),     # tok_e vec_e tok_e tok_d tok_d
      encoder_side_mask,                # mask_e vec_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 1, 0, 2]),          #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch([], decoder_side_mask),  # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      concat,                           # vec_ed tok_e tok_d

      # Decoder blocks with causal attention
      decoder_blocks,                   # vec_ed tok_e tok_d
      tl.LayerNorm(),                   # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),          # vec_ed tok_e tok_d tok_d
      strip_encoder_part,               # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),      # vec_d tok_d
      tl.LogSoftmax(),                  # vec_d tok_d
  )
Ejemplo n.º 24
0
 def test_second_of_3(self):
     """Selecting index 1 out of three inputs returns the second input."""
     first = np.array([1, 2, 3])
     second = np.array([10, 20])
     third = np.array([100])
     pick_second = tl.Select([1], n_in=3)
     selected = pick_second([first, second, third])
     self.assertEqual(as_list(selected), [10, 20])
Ejemplo n.º 25
0
def Transformer2(input_vocab_size,
                 output_vocab_size=None,
                 d_model=512,
                 d_ff=2048,
                 n_encoder_layers=6,
                 n_decoder_layers=6,
                 n_heads=8,
                 dropout=0.1,
                 dropout_shared_axes=None,
                 max_len=2048,
                 mode='train',
                 ff_activation=tl.Relu,
                 ff_dropout=0.1,
                 ff_chunk_size=0,
                 ff_use_sru=0,
                 ff_sparsity=0,
                 ff_sparsity_type='1inN',
                 attention_chunk_size=0,
                 encoder_attention_type=tl.Attention,
                 n_encoder_attention_layers=1,
                 decoder_attention_type=tl.CausalAttention,
                 n_decoder_attention_layers=2,
                 axial_pos_shape=None,
                 d_axial_pos_embs=None):
    """Returns a Transformer model.

    This model expects an input pair: source, target. The encoder-side
    (source) tokens come first and the decoder-side (target) tokens second;
    that is the order in which the layer stack below consumes them.

    The encoder output and the embedded, right-shifted target are joined by
    `ConcatWithPadding`, processed together by the decoder blocks, and the
    decoder part is recovered by `StripFromConcatenateWithPadding` before the
    final projection to the output vocabulary.

    Args:
      input_vocab_size: int: vocab size of the source.
      output_vocab_size: int (optional): vocab size of the target. If None,
        the source and target are assumed to have the same vocab.
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_encoder_layers: int: number of encoder layers
      n_decoder_layers: int: number of decoder layers
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      dropout_shared_axes: axes on which to share dropout mask
      max_len: int: maximum symbol length for positional encoding
      mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
        is wrapped in `tl.Cache`.
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      ff_use_sru: int; if > 0, we use this many SRU layers instead of
        feed-forward
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity
      ff_sparsity_type: string, if ff_sparsity >0,
        use SparseFF if ff_sparsity_type=`'1inN'` and
        use BlockSparseFF if ff_sparsity_type=`'Block'`
      attention_chunk_size: int, if > 0 run attention chunked at this size
      encoder_attention_type: The attention layer to use for the encoder
        part.
      n_encoder_attention_layers: int, within each encoder block, how many
        attention layers to have.
      decoder_attention_type: The attention layer passed to each decoder
        block (the decoder blocks run over the concatenated encoder/decoder
        sequence).
      n_decoder_attention_layers: int, within each decoder block, how many
        attention layers to have.
      axial_pos_shape: tuple of ints: input shape to use for the axial
        position encoding. If unset, axial position encoding is disabled.
      d_axial_pos_embs: tuple of ints: depth of position embedding for each
        axis. Tuple length must match axial_pos_shape, and values must sum to
        d_model.

    Returns:
      A Transformer model as a layer that maps from a source, target pair to
      activations over a vocab set. Per the stack annotations below, a copy
      of the target tokens remains on the stack after the activations (for
      use in the loss).
    """
    # The helper also resolves output_vocab_size (e.g. when it was None).
    in_encoder, out_encoder, output_vocab_size = (
        ct.EmbeddingAndPositionalEncodings(input_vocab_size,
                                           d_model,
                                           mode,
                                           dropout,
                                           dropout_shared_axes,
                                           max_len,
                                           output_vocab_size=output_vocab_size,
                                           axial_pos_shape=axial_pos_shape,
                                           d_axial_pos_embs=d_axial_pos_embs))

    # pylint: disable=g-complex-comprehension
    encoder_blocks = [
        ct.EncoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, encoder_attention_type,
                        n_encoder_attention_layers)
        for _ in range(n_encoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    encoder = tl.Serial(in_encoder, encoder_blocks, tl.LayerNorm())
    if mode == 'predict':
        encoder = tl.Cache(encoder)

    # pylint: disable=g-complex-comprehension
    decoder_blocks = [
        ct.DecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                        mode, ff_activation, ff_dropout, ff_chunk_size,
                        ff_use_sru, ff_sparsity, ff_sparsity_type,
                        attention_chunk_size, decoder_attention_type,
                        n_decoder_attention_layers)
        for _ in range(n_decoder_layers)
    ]
    # pylint: enable=g-complex-comprehension

    # Assemble and return the model. The trailing comment on each layer shows
    # the data stack (top first) after that layer has run.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 0, 1, 1]),  # tok_e tok_e tok_d tok_d

        # Encode.
        tl.Branch([], tl.PaddingMask()),  # tok_e mask_e tok_e tok_d tok_d
        encoder,  # vec_e mask_e tok_e tok_d tok_d

        # Simple encoder mask, doesn't contain extra dims.
        tl.Select([2, 0, 2], n_in=3),  #  tok_e vec_e tok_e tok_d tok_d
        tl.Fn(
            'EncoderMask',  # mask_e vec_e tok_e tok_d tok_d
            lambda x: x != 0,
            n_out=1),

        # Decode.
        tl.Select([3, 1, 0, 2]),  #  tok_d vec_e mask_e tok_e tok_d
        tl.ShiftRight(mode=mode),  # stok_d vec_e mask_e tok_e tok_d
        out_encoder,  # svec_d vec_e mask_e tok_e tok_d

        # Concat encoder and decoder.
        tl.Select([1, 0]),  # vec_e svec_d mask_e tok_e tok_d
        ConcatWithPadding(mode=mode),  # vec_ed tok_e tok_d

        # Decoder blocks with causal attention
        decoder_blocks,  # vec_ed tok_e tok_d
        tl.LayerNorm(),  # vec_ed tok_e tok_d

        # Separate out the encoder part from the concatenated vector.
        tl.Select([0, 1, 2, 2]),  # vec_ed tok_e tok_d tok_d
        StripFromConcatenateWithPadding(mode=mode),  # vec_d tok_d

        # Map to output vocab.
        tl.Dense(output_vocab_size),  # vec_d tok_d
    )
Ejemplo n.º 26
0
def ReformerNoEncDecAttention(input_vocab_size,
                              output_vocab_size=None,
                              d_model=512,
                              d_ff=2048,
                              d_attention_key=64,
                              d_attention_value=64,
                              n_encoder_layers=6,
                              n_decoder_layers=6,
                              n_heads=8,
                              dropout=0.1,
                              max_len=2048,
                              encoder_attention_type=tl.SelfAttention,
                              encoder_decoder_attention_type=tl.SelfAttention,
                              axial_pos_shape=(),
                              d_axial_pos_embs=None,
                              ff_activation=tl.Relu,
                              ff_use_sru=0,
                              ff_chunk_size=0,
                              ff_dropout=None,
                              mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target.

  The encoder output and the embedded, right-shifted target are joined by
  `_ConcatWithPadding`; the reversible decoder blocks run over the joined
  sequence, and `_StripFromConcatenateWithPadding` recovers the decoder part
  before the projection to the output vocabulary.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
      (forwarded to the decoder blocks)
    d_attention_value: int: depth of value vector for each attention head
      (forwarded to the decoder blocks)
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    encoder_attention_type: class: attention class to use in the encoder
      blocks, such as SelfAttention
    encoder_decoder_attention_type: class, or tuple/list of classes: attention
      class(es) for the decoder blocks. A sequence is cycled through layer by
      layer, and n_decoder_layers must be divisible by its length.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, and values must sum to d_model.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of
      feed-forward (forwarded to the decoder blocks)
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
      (forwarded to the decoder blocks)
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T. (Forwarded to the
      encoder blocks.)
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      side is built in 'eval' mode and wrapped in `tl.Cache`.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    log-softmax activations over a vocab set; per the stack annotations, a
    copy of the target tokens follows them on the stack.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  if fastmath.backend_name() == 'jax':
    jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # Axial encoding is used only when a non-empty axial shape was supplied.
    if axial_pos_shape:
      assert d_axial_pos_embs is not None
      positional_encoding = tl.AxialPositionalEncoding(
          shape=axial_pos_shape, d_embs=d_axial_pos_embs,
          dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
          dropout=dropout, mode=mode)
    else:
      positional_encoding = tl.PositionalEncoding(
          max_len=max_len, dropout=dropout, mode=mode)

    token_embedding = tl.Embedding(vocab_size, d_model)
    embedding_dropout = tl.Dropout(rate=dropout, shared_axes=[-2], mode=mode)
    return [token_embedding, embedding_dropout, positional_encoding]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = PositionalEncoder(input_vocab_size, mode=encoder_mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # pylint: disable=g-complex-comprehension
  reversible_encoder_blocks = [
      EncoderBlock(
          d_model, d_ff, n_heads, encoder_attention_type, dropout,
          ff_activation, ff_dropout, mode)
      for _ in range(n_encoder_layers)]
  # pylint: enable=g-complex-comprehension

  # Stack on entry to the encoder:        tok_e mask_e tok_e tok_d tok_d
  encoder = tl.Serial([
      in_encoder,                      # vec_e mask_e tok_e tok_d tok_d
      tl.Dup(),                        # vec_e1 vec_e2 mask_e tok_e tok_d tok_d
      tl.ReversibleSerial(reversible_encoder_blocks),
      # Merge the two reversible streams by averaging them.
      tl.Fn('XYAvg', lambda x, y: (x + y) / 2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  # Normalise the attention spec to a sequence that is cycled over the layers.
  if isinstance(encoder_decoder_attention_type, (tuple, list)):
    attention_cycle = encoder_decoder_attention_type
    assert n_decoder_layers % len(attention_cycle) == 0
  else:
    attention_cycle = [encoder_decoder_attention_type]

  # pylint: disable=g-complex-comprehension
  decoder_blocks = [
      DecoderBlock(
          d_model, d_ff, d_attention_key, d_attention_value, n_heads,
          attention_type=attention_cycle[layer_idx % len(attention_cycle)],
          dropout=dropout,
          ff_activation=ff_activation,
          ff_use_sru=ff_use_sru,
          ff_chunk_size=ff_chunk_size,
          mode=mode)
      for layer_idx in range(n_decoder_layers)]
  # pylint: enable=g-complex-comprehension

  # Encoder padding mask with axes 1 and 2 squeezed out.
  squeezed_padding_mask = [
      tl.PaddingMask(),
      tl.Fn('Squeeze', lambda x: jnp.squeeze(x, (1, 2)), n_out=1),
  ]

  # Assemble and return the model. The trailing comment on each layer shows
  # the data stack (top first) after that layer has run.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 0, 1, 1]),          # tok_e tok_e tok_d tok_d
      tl.Branch([], squeezed_padding_mask),
      #                                 # tok_e mask_e tok_e tok_d tok_d

      # Encode.
      encoder,                          # vec_e mask_e tok_e tok_d tok_d

      # Decode.
      tl.Select([3, 0, 1, 2]),          #  tok_d vec_e mask_e tok_e tok_d
      tl.ShiftRight(mode=mode),         # stok_d vec_e mask_e tok_e tok_d
      tl.Branch(
          [],
          _MaskOfRightShiftedArray()
      ),                                # stok_d mask_d vec_e mask_e tok_e tok_d
      out_encoder,                      # svec_d mask_d vec_e mask_e tok_e tok_d

      # Concat encoder and decoder, given their masks.
      tl.Select([2, 0, 3, 1]),          # vec_e svec_d mask_e mask_d tok_e tok_d
      _ConcatWithPadding(),             # vec_ed tok_e tok_d

      # Run (encoder and) decoder blocks.
      tl.Dup(),                         # vec_ed1 vec_ed2 tok_e tok_d
      tl.ReversibleSerial(decoder_blocks),  # vec_ed1 vec_ed2 tok_e tok_d
      tl.Fn('XYAvg',
            lambda x, y: (x + y) / 2.0),    # vec_ed tok_e tok_d
      tl.LayerNorm(),                   # vec_ed tok_e tok_d

      # Separate out the encoder part from the concatenated vector.
      tl.Select([0, 1, 2, 2]),          # vec_ed tok_e tok_d tok_d
      _StripFromConcatenateWithPadding(),   # vec_d tok_d

      # Map to output vocab.
      tl.Dense(output_vocab_size),      # vec_d tok_d
      tl.LogSoftmax(),                  # vec_d tok_d
  )
Ejemplo n.º 27
0
 def advantage_std(self):
     """Returns a layer computing the standard deviation of the advantages."""
     # Incoming stack: (dist_inputs, advantages, old_dist_inputs, mask);
     # index 1 picks out just the advantages.
     pick_advantages = tl.Select([1])
     # The explicit lambda wrapper around jnp.std is kept on purpose
     # (hence the pylint suppression).
     std_of_advantages = tl.Fn('AdvantageStd', lambda x: jnp.std(x))  # pylint: disable=unnecessary-lambda
     return tl.Serial([pick_advantages, std_of_advantages])
Ejemplo n.º 28
0
def Reformer(input_vocab_size,
             output_vocab_size=None,
             d_model=512,
             d_ff=2048,
             n_encoder_layers=6,
             n_decoder_layers=6,
             n_heads=8,
             dropout=0.1,
             max_len=2048,
             ff_activation=tl.Relu,
             ff_dropout=None,
             mode='train'):
  """Reversible transformer encoder-decoder model.

  This model expects an input pair: source, target. The encoder-side tokens
  come first and the decoder-side tokens second; that is the order in which
  the layer stack below consumes them.

  At the moment, this model supports dot-product attention only. For the
  attention types in the Reformer paper, see ReformerLM.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate at feed-forward
      nonlinearity. This is called relu_dropout in T2T.
    mode: str: 'train', 'eval' or 'predict'. In 'predict' mode the encoder
      side is built in 'eval' mode and wrapped in `tl.Cache`.

  Returns:
    A Reformer model as a layer that maps from a source, target pair to
    log-softmax activations over a vocab set.
  """
  # The current API for custom gradients assumes that a layer must be
  # differentiable wrt all of its inputs, but the Transformer puts bool-dtype
  # masks on the stack. This causes jax to error, even though the so-called
  # "gradient" wrt the masks is never actually computed.
  # TODO(kitaev): remove this hack.
  jax.api._check_inexact_input_vjp = lambda x: None  # pylint: disable=protected-access

  def PositionalEncoder(vocab_size, mode):  # tokens --> vectors
    # TODO(kitaev): axial positional encoding is better for very long sequences.
    position_layer = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)
    token_layer = tl.Embedding(d_model, vocab_size)
    dropout_layer = BroadcastedDropout(rate=dropout, mode=mode)
    return [token_layer, dropout_layer, position_layer]

  # TODO(kitaev): The regular trax Transformer shares vocab embeddings and
  # position embeddings between the encoder and decoder if output_vocab_size is
  # None. This isn't supported here because (a) Trax shares weights by sharing
  # layer instances, but we need two separate instances to have mode == 'eval'
  # for the encoder but mode == 'predict' for the decoder; and (b) tl.Cache does
  # not work if its sublayers participate in any weight sharing.

  # Mode 'predict' means that the decoder should be run one token at a time.
  # The encoder only ever runs over full sequences, which is why it's switched
  # to 'eval' mode instead.
  encoder_mode = 'eval' if mode == 'predict' else mode
  in_encoder = PositionalEncoder(input_vocab_size, mode=encoder_mode)
  if output_vocab_size is None:
    output_vocab_size = input_vocab_size
  out_encoder = PositionalEncoder(output_vocab_size, mode)

  # Both block families are built from the same positional arguments.
  block_args = (
      d_model, d_ff, n_heads, dropout, ff_activation, ff_dropout, mode)

  encoder_blocks = [
      EncoderBlock(*block_args) for _ in range(n_encoder_layers)]

  encoder = tl.Serial([
      in_encoder,
      tl.Dup(),
      tl.ReversibleSerial(encoder_blocks),
      # Merge the two reversible streams by averaging them.
      tl.Fn(lambda x, y: (x+y)/2.0),
      tl.LayerNorm(),
  ])
  if mode == 'predict':
    encoder = tl.Cache(encoder)

  encoder_decoder_blocks = [
      EncoderDecoderBlock(*block_args) for _ in range(n_decoder_layers)]

  # Encoder padding mask with axes 1 and 2 squeezed out.
  squeezed_padding_mask = [
      tl.PaddingMask(),
      tl.Fn(lambda x: np.squeeze(x, (1, 2)), n_out=1),
  ]

  # Assemble and return the model. The trailing comment on each layer shows
  # the data stack (top first) after that layer has run.
  return tl.Serial(
      # Input: encoder_side_tokens, decoder_side_tokens
      # Copy decoder tokens for use in loss.
      tl.Select([0, 1, 1]),                 # tok_e tok_d tok_d
      tl.Branch([], squeezed_padding_mask),  # tok_e mask  tok_d .....

      # Encode.
      encoder,                              # vec_e  mask tok_d .....

      # Decode.
      tl.Select([2, 0, 1]),                 # tok_d vec_e mask .....
      tl.ShiftRight(mode=mode),             # tok_d vec_e mask .....
      out_encoder,                          # vec_d vec_e mask .....
      tl.Dup(),                             # vec_d1 vec_d2 vec_e mask .....
      tl.ReversibleSerial(encoder_decoder_blocks),
      tl.Fn(lambda x, y: (x+y)/2.0),        # vec_d vec_e mask .....
      tl.LayerNorm(),                       # vec_d vec_e mask .....

      # Map to output vocab.
      tl.Select([0], n_in=3),               # vec_d .....
      tl.Dense(output_vocab_size),          # vec_d .....
      tl.LogSoftmax(),                      # vec_d .....
  )
Ejemplo n.º 29
0
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """BERT (default hparams are for bert-base-uncased).

  Assembles the embeddings, `n_layers` encoder layers and a pooler into a
  `PretrainedBERT`, optionally followed by a task head.

  Args:
    d_model: int: depth of the embeddings and of each encoder layer's output.
    vocab_size: int: size of the word-embedding vocabulary.
    max_len: int: maximum length passed to the positional encoding.
    type_vocab_size: int: size of the type-embedding vocabulary.
    n_heads: int: number of attention heads; each head has depth
      `d_model // n_heads`.
    d_ff: int: depth of the feed-forward hidden layer.
    n_layers: int: number of encoder layers.
    head: optional zero-argument callable returning a layer; when given, its
      result is appended after the BERT body.
    init_checkpoint: checkpoint handed to `PretrainedBERT`; only passed
      through when `mode == 'train'`, otherwise replaced by None.
    mode: str: forwarded to the positional encoding and attention layers.

  Returns:
    The `PretrainedBERT` layer, or a `tl.Serial` of it and `head()` when a
    head is supplied.
  """
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)

  # Padding mask with axes 1 and 2 squeezed out; built from a copy of the
  # first input.
  squeezed_padding_mask = [
      tl.PaddingMask(),
      tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1),
  ]
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drops 'idx' input.
      tl.Parallel(word_embeddings, type_embeddings, squeezed_padding_mask),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  def EncoderLayer():
    """Builds the layers of one encoder block (fresh instances per call)."""
    self_attention = tl.SelfAttention(
        n_heads=n_heads, d_qk=d_head, d_v=d_head,
        bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model),
    ]
    return [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(self_attention, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

  encoder = []
  for _ in range(n_layers):
    encoder.extend(EncoderLayer())
  encoder.append(tl.Select([0], n_in=2))  # Drop the mask

  pooler = [
      # Emits the slice at position 0 along axis 1, followed by the full
      # input; the Dense and Tanh below act on the first of the two.
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  # The checkpoint is only handed over in 'train' mode.
  if mode != 'train':
    init_checkpoint = None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is None:
    return bert
  return tl.Serial(bert, head())
Ejemplo n.º 30
0
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      ff_chunk_size=0,
                      mode='train'):
  """Reversible transformer language model with shortening.

  When shorten_factor is F and processing an input of shape [batch, length],
  we embed the (shifted-right) input and then group each F elements (on length)
  into a single vector -- so that in the end we process a tensor of shape
    [batch, length // F, d_model]
  almost until the end -- at the end it's un-shortened and a SRU is applied.
  This reduces the length processed inside the main model body, effectively
  making the model faster but possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class, or tuple/list of classes: attention class(es) to
      use, such as DotProductAttention. A sequence is cycled through layer by
      layer, and n_layers must be divisible by its length.
    share_qk: bool, whether to share queries and keys. Sharing is also turned
      on for any layer whose attention class subclasses
      `tl.LSHCausalAttention`.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval' ('predict' is rejected by an assertion)

  Returns:
    the layer.
  """
  assert mode != 'predict'  # TODO(lukaszkaiser,kitaev): fast inference

  # Axial encoding is used only when a non-empty axial shape was supplied.
  if axial_pos_shape:
    assert d_axial_pos_embs is not None
    positional_encoding = tl.AxialPositionalEncoding(
        shape=axial_pos_shape, d_embs=d_axial_pos_embs,
        dropout_broadcast_dims=tuple(range(1, len(axial_pos_shape) + 1)),
        dropout=dropout, mode=mode)
  else:
    positional_encoding = tl.PositionalEncoding(
        max_len=max_len, dropout=dropout, mode=mode)

  positional_embedder = [
      tl.Embedding(d_embedding, vocab_size),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      positional_encoding,
  ]

  # Normalise the attention spec to a sequence that is cycled over the layers.
  if isinstance(attention_type, (tuple, list)):
    attention_cycle = attention_type
    assert n_layers % len(attention_cycle) == 0
  else:
    attention_cycle = [attention_type]

  def MakeDecoderBlock(layer_attention_type):
    """Builds one reversible decoder block for the given attention class."""
    return DecoderBlock(
        d_model, d_ff, d_attention_key, d_attention_value, n_heads,
        n_attention_chunks,
        attention_type=layer_attention_type,
        dropout=dropout,
        share_qk=(share_qk or issubclass(layer_attention_type,
                                         tl.LSHCausalAttention)),
        ff_activation=ff_activation,
        ff_use_sru=ff_use_sru,
        ff_chunk_size=ff_chunk_size,
        mode=mode)

  decoder_blocks = [
      MakeDecoderBlock(attention_cycle[layer_idx % len(attention_cycle)])
      for layer_idx in range(n_layers)]

  # pylint: disable=g-long-lambda
  return tl.Serial(
      tl.ShiftRight(),
      positional_embedder,
      tl.Dup(),              # Stack has (x, x), the first will be shortened
      # Before shortening, we need to pad by shorten factor so as not to leak
      # information into the future. To understand why, imagine shorten factor
      # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
      # would have 0ABC, which gets grouped to [0A][BC] on input, which is
      # predicting ABCD as targets. The problem is that [0A] has access to A
      # and [BC] has access to C -- it will learn to copy it, peek into
      # the future. Shifting twice to [00][AB] solves the problem as the first
      # "big" symbol becomes all-0 and the rest is shifted enough.
      tl.ShiftRight(n_shifts=shorten_factor - 1),
      tl.Fn(lambda x: np.reshape(  # Shorten -- move to depth.
          x, (x.shape[0], x.shape[1] // shorten_factor, -1)), n_out=1),
      tl.Dense(d_model),
      tl.Dup(),  # Stack has (short_x, short_x, x)
      tl.ReversibleSerial(decoder_blocks),
      tl.Select([0], n_in=2),  # Keep only the first of the two streams.
      tl.LayerNorm(),
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
      tl.Dense(shorten_factor * d_embedding),
      tl.Fn(lambda x: np.reshape(  # Prolong back.
          x, (x.shape[0], x.shape[1] * shorten_factor, -1)), n_out=1),
      tl.Concatenate(),  # Concatenate with just the embeddings.
      tl.CausalConv(d_embedding),
      tl.Relu(),
      tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )