Example #1
def AttentionPosition(positions,
                      d_model,
                      n_heads=8,
                      dropout=0.0,
                      mode='train'):
    """Transformer-style multi-headed attention."""
    return tl.Serial(
        tl.Dup(),
        tl.Dup(),
        tl.Parallel(
            ApplyAndQueryPositions(
                tl.Dense(d_model),
                pos=[SumLearnedPick(positions) for _ in range(n_heads)]),
            PreservePosition(tl.Dense(d_model)),
            PreservePosition(tl.Dense(d_model)),
        ),
        tl.Parallel(
            CopyHeadsPos(h=n_heads),
            MixHeadsPos(h=n_heads),
            MixHeadsPos(h=n_heads),
        ),
        tl.PureAttention(d_model=d_model,
                         n_heads=n_heads,
                         dropout=dropout,
                         mode=mode),
        tl.Parallel([], tl.Drop()),  # Drop the mask.
        CombineHeadsPos(h=n_heads),
        PreservePosition(tl.Dense(d_model)),
    )
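The combinators in this example are easiest to read as operations on a data stack: tl.Dup pushes a copy of the top item, tl.Parallel applies its sublayers to successive stack items, and tl.Drop discards the top item. Below is a minimal pure-Python sketch of those stack semantics (an illustration only; the real Trax combinators also thread weights, state, and RNGs through each layer):

import numpy as np

# Toy stack semantics for Serial / Dup / Parallel / Drop. Each toy "layer"
# maps a list (top of stack first) to a new list.
def dup(stack):
    return [stack[0]] + stack                 # duplicate the top item

def drop(stack):
    return stack[1:]                          # discard the top item

def parallel(*fns):
    def apply(stack):                         # fn[i] acts on the i-th item
        return [f(x) for f, x in zip(fns, stack)] + stack[len(fns):]
    return apply

def serial(*layers):
    def apply(stack):
        for layer in layers:
            stack = layer(stack)
        return stack
    return apply

identity = lambda x: x
proj = lambda x: 2.0 * x                      # stand-in for tl.Dense

# Mirror of the Dup, Dup, Parallel(q, k, v) prefix of AttentionPosition.
qkv = serial(dup, dup, parallel(proj, proj, identity))
q, k, v = qkv([np.ones((2, 3))])
print(q.shape, k.shape, v.shape)              # (2, 3) (2, 3) (2, 3)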
Example #2
def SerializedModel(
    seq_model,
    observation_serializer,
    action_serializer,
    significance_decay,
):
    """Wraps a world model in serialization machinery for training.

  The resulting model takes the observation and action sequences as input,
  serializes them and interleaves them into one sequence, which is fed into the
  given autoregressive model. The resulting logit sequence is deinterleaved into
  observations and actions, and the observation logits are returned together
  with the computed symbol significance weights.

  Args:
    seq_model: Trax autoregressive model taking as input a sequence of symbols
      and outputting a sequence of symbol logits.
    observation_serializer: Serializer to use for observations.
    action_serializer: Serializer to use for actions.
    significance_decay: Float from (0, 1) for exponential weighting of symbols
      in the representation.

  Returns:
    A model of signature
    (obs, act, obs, mask) -> (obs_logits, obs_repr, weights), where obs are
    observations (the second occurrence is the target), act are actions, mask is
    the observation mask, obs_logits are logits of the output observation
    representation, obs_repr is the target observation representation and
    weights are the target weights.
  """
    # pylint: disable=no-value-for-parameter
    weigh_by_significance = [
        # (mask,)
        RepresentationMask(serializer=observation_serializer),
        # (repr_mask)
        SignificanceWeights(serializer=observation_serializer,
                            decay=significance_decay),
        # (mask, sig_weights)
    ]
    return tl.Serial(
        # (obs, act, obs, mask)
        tl.Parallel(Serialize(serializer=observation_serializer),
                    Serialize(serializer=action_serializer),
                    Serialize(serializer=observation_serializer)),
        # (obs_repr, act_repr, obs_repr, mask)
        Interleave(),
        # (obs_act_repr, obs_repr, mask)
        seq_model,
        # (obs_act_logits, obs_repr, mask)
        Deinterleave(x_size=observation_serializer.representation_length,
                     y_size=action_serializer.representation_length),
        # (obs_logits, act_logits, obs_repr, mask)
        tl.Parallel(None, tl.Drop(), None, weigh_by_significance),
        # (obs_logits, obs_repr, weights)
    )
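As the docstring describes, the serialized observation and action representations are interleaved into one symbol sequence before the sequence model and deinterleaved afterwards. A small numpy sketch of that round trip, with hypothetical representation lengths chosen just to show the reshapes:

import numpy as np

batch, n_steps = 2, 4
obs_len, act_len = 3, 1     # hypothetical representation lengths per time step

obs_repr = np.arange(batch * n_steps * obs_len).reshape(batch, n_steps, obs_len)
act_repr = 100 + np.arange(batch * n_steps * act_len).reshape(batch, n_steps, act_len)

# Interleave: per time step, observation symbols followed by action symbols.
interleaved = np.concatenate([obs_repr, act_repr], axis=2)
interleaved = interleaved.reshape(batch, n_steps * (obs_len + act_len))

# Deinterleave: undo the reshape and split the two streams back out.
chunks = interleaved.reshape(batch, n_steps, obs_len + act_len)
obs_back, act_back = chunks[..., :obs_len], chunks[..., obs_len:]
assert np.array_equal(obs_back, obs_repr) and np.array_equal(act_back, act_repr)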
Example #3
    def __init__(
        self,
        seq_model,
        observation_serializer,
        action_serializer,
        significance_decay,
        mode='train',
    ):
        """Initializes SerializedModel.

    Args:
      seq_model: Trax autoregressive model taking as input a sequence of symbols
        and outputting a sequence of symbol logits.
      observation_serializer: Serializer to use for observations.
      action_serializer: Serializer to use for actions.
      significance_decay: Float from (0, 1) for exponential weighting of symbols
        in the representation.
      mode: 'train' or 'eval'.
    """
        assert mode in ('train', 'eval')
        weigh_by_significance = [
            # (mask,)
            RepresentationMask(serializer=observation_serializer),
            # (repr_mask)
            SignificanceWeights(serializer=observation_serializer,
                                decay=significance_decay),
            # (mask, sig_weights)
        ]
        super().__init__(
            # (obs, act, obs, mask)
            tl.Parallel(Serialize(serializer=observation_serializer),
                        Serialize(serializer=action_serializer),
                        Serialize(serializer=observation_serializer)),
            # (obs_repr, act_repr, obs_repr, mask)
            Interleave(),
            # (obs_act_repr, obs_repr, mask)
            seq_model(mode=mode),
            # (obs_act_logits, obs_repr, mask)
            Deinterleave(x_size=observation_serializer.representation_length,
                         y_size=action_serializer.representation_length),
            # (obs_logits, act_logits, obs_repr, mask)
            tl.Parallel(None, tl.Drop(), None, weigh_by_significance),
            # (obs_logits, obs_repr, weights)
        )

        self._seq_model = seq_model
        self._observation_serializer = observation_serializer
        self._action_serializer = action_serializer
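One plausible reading of significance_decay (an illustration of the docstring, not the exact SignificanceWeights implementation): within each serialized observation, the i-th symbol gets weight decay**i, so leading symbols dominate the loss, and the result is multiplied by the per-step observation mask:

import numpy as np

decay, obs_len, n_steps = 0.9, 3, 4
mask = np.array([[1, 1, 1, 0]])              # last time step is padding

per_symbol = decay ** np.arange(obs_len)     # [1.0, 0.9, 0.81]
# One weight per serialized symbol: tile over time, then mask padded steps.
weights = np.repeat(mask, obs_len, axis=1) * np.tile(per_symbol, (1, n_steps))
print(weights.round(2))   # [[1. 0.9 0.81 1. 0.9 0.81 1. 0.9 0.81 0. 0. 0.]]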
Example #4
def TransformerEncoder(vocab_size,
                       n_classes=10,
                       d_model=512,
                       d_ff=2048,
                       n_layers=6,
                       n_heads=8,
                       dropout=0.1,
                       max_len=2048,
                       mode='train',
                       ff_activation=tl.Relu):
    """Returns a Transformer encoder model.

  The input to the model is a tensor of tokens.

  Args:
    vocab_size: int: vocab size
    n_classes: how many classes on output
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of encoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a tensor of tokens to
    activations over a set of output classes.
  """
    embedder = [
        tl.Embedding(d_model, vocab_size),
        tl.Dropout(rate=dropout, name='emb_dropout', mode=mode),
        tl.PositionalEncoding(max_len=max_len),
    ]
    return tl.Serial(  #      tokens
        tl.Dup(),  # toks toks
        tl.Parallel(embedder, tl.PaddingMask()),  # vecs mask
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                         ff_activation) for i in range(n_layers)
        ],  # vecs mask
        tl.Parallel([], tl.Drop()),  # Drop the mask.  # vecs
        tl.LayerNorm(),  # vecs
        tl.Mean(axis=1),  # Average on length.    # vecs
        tl.Dense(n_classes),  # vecs
        tl.LogSoftmax(),  # vecs
    )
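The tail of this encoder (tl.Mean(axis=1), tl.Dense, tl.LogSoftmax) is what turns per-token activations into per-example class log-probabilities. A plain numpy sketch of that pooling head with random weights, just to make the shapes concrete:

import numpy as np

batch, length, d_model, n_classes = 8, 16, 512, 10
vecs = np.random.randn(batch, length, d_model)

pooled = vecs.mean(axis=1)                              # [batch, d_model]
w = np.random.randn(d_model, n_classes) / np.sqrt(d_model)
logits = pooled @ w                                     # [batch, n_classes]
log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
print(log_probs.shape)                                  # (8, 10)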
Example #5
def RNNLM(vocab_size,
          d_model=512,
          n_layers=2,
          rnn_cell=tl.LSTMCell,
          rnn_cell_d_state_multiplier=2,
          dropout=0.1,
          mode='train'):
  """Returns an RNN language model.

  The input to the model is a tensor of tokens (ints).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding (n_units in the RNN cell)
    n_layers: int: number of RNN layers
    rnn_cell: the RNN cell
    rnn_cell_d_state_multiplier: how many times larger the RNN cell state is
      than its output (e.g. 2 for LSTM, which carries both c and h)
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train', 'eval' or 'predict'; predict mode is for fast inference

  Returns:
    An RNN language model as a layer that maps from a tensor of tokens
    to activations over a vocab set.
  """
  def MultiRNNCell():
    """Multi-layer RNN cell."""
    assert n_layers == 2
    return tl.Serial(
        tl.Parallel([], tl.Split(n_items=n_layers)),
        tl.SerialWithSideOutputs(
            [rnn_cell(n_units=d_model) for _ in range(n_layers)]),
        tl.Parallel([], tl.Concatenate(n_items=n_layers))
    )

  return tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, name='embedding', mode=mode),
      tl.Dup(),  # Duplicate to create parallel state.
      tl.Parallel([], tl.MakeZeroState(  # pylint: disable=no-value-for-parameter
          depth_multiplier=n_layers * rnn_cell_d_state_multiplier)),
      tl.Scan(MultiRNNCell(), axis=1),
      tl.Parallel([], tl.Drop()),  # Drop RNN state.
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
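MultiRNNCell above expects the incoming state to hold every layer's cell state concatenated along the last axis, which is why MakeZeroState is given depth_multiplier = n_layers * rnn_cell_d_state_multiplier (an LSTM keeps two tensors, c and h, per layer). A numpy sketch of that layout, assuming the state is a plain concatenation:

import numpy as np

batch, d_model = 4, 512
n_layers, d_state_multiplier = 2, 2          # LSTM: (c, h) per layer

# Flat zero state covering both layers, as built by MakeZeroState.
state = np.zeros((batch, d_model * n_layers * d_state_multiplier))

per_layer = np.split(state, n_layers, axis=-1)     # 2 x [batch, 2 * d_model]
c0, h0 = np.split(per_layer[0], 2, axis=-1)        # each [batch, d_model]
print(c0.shape, h0.shape)                          # (4, 512) (4, 512)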
Example #6
def PositionLookupTransformerLM(vocab_size=128,
                                d_model=256,
                                d_ff=512,
                                n_layers=3,
                                n_heads=4,
                                dropout=0.1,
                                max_len=100,
                                mode='train'):
  """Transformer language model (only uses the decoder part of Transformer).

  Args:
    vocab_size: int: vocab size
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_layers: int: number of layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: maximal length
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  positions = _POSITIONS[:max_len, :]
  return tl.Serial(
      tl.ShiftRight(),
      tl.Embedding(d_model, vocab_size),
      tl.Dropout(rate=dropout, mode=mode),
      tl.Dup(),
      tl.Parallel([], NewPositionalEncoding(positions=positions)),
      [DecoderLayer(positions, d_model, d_ff, n_heads, dropout, mode)
       for _ in range(n_layers)],
      tl.Parallel([], tl.Drop()),  # Drop positions.
      tl.LayerNorm(),
      tl.Dense(vocab_size),
      tl.LogSoftmax()
  )
Example #7
    def test_drop(self):
        layer = tl.Drop()
        x = np.array([1, 2, 3])
        y = layer(x)
        self.assertEqual(as_list(y), [])
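For contrast with Drop, a companion sketch along the same lines (assuming the same as_list helper and test class), showing that tl.Dup pushes a copy of its input onto the stack instead of discarding it:

    def test_dup(self):
        # tl.Dup copies the top stack item, so one input yields two outputs.
        layer = tl.Dup()
        x = np.array([1, 2, 3])
        ys = layer(x)
        self.assertEqual(as_list(ys), [[1, 2, 3], [1, 2, 3]])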
Example #8
OptState = collections.namedtuple(
    '_OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

_DEFAULT_METRICS = {
    'loss':
    tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss()),
    'accuracy':
    tl.Accuracy(),
    'sequence_accuracy':
    tl.SequenceAccuracy(),
    'neg_log_perplexity':
    tl.Serial(tl.LogSoftmax(), tl.CrossEntropyLoss(), tl.Negate()),
    'weights_per_batch_per_core':
    tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}


class Trainer:
    """Trax trainer.

  A trainer allows you to take training steps, train for full epochs,
  save the training state, and access evaluation data.
  """
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
Example #9
OptState = collections.namedtuple(
    '_OptState',
    [
        'weights',  # Model weights.
        'slots',  # Per-parameter optimizer state, e.g. gradient moments.
        'opt_params',  # Optimizer (hyper)parameters, e.g. learning rate, momentum.
    ])

_DEFAULT_METRICS = {
    'loss': tl.WeightedCategoryCrossEntropy(),
    'accuracy': tl.WeightedCategoryAccuracy(),
    'sequence_accuracy': tl.MaskedSequenceAccuracy(),
    'neg_log_perplexity': tl.Serial(tl.WeightedCategoryCrossEntropy(),
                                    tl.Negate()),
    'weights_per_batch_per_core': tl.Serial(tl.Drop(), tl.Drop(), tl.Sum()),
}


class Trainer:
    """Trax trainer.

  A trainer allows you to take training steps, train for full epochs,
  save the training state, and access evaluation data.
  """
    def __init__(self,
                 model,
                 loss_fn,
                 optimizer,
                 lr_schedule,
                 inputs,
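Each entry in _DEFAULT_METRICS is itself a layer that consumes the (model output, targets, weights) triple: 'weights_per_batch_per_core' drops the first two inputs and sums the weights, and 'neg_log_perplexity' negates the weighted cross entropy. A small numpy sketch of those reductions, starting from per-example log-probabilities (the real layers additionally handle batching, masking and, depending on the version, the softmax):

import numpy as np

# Toy (log_probs, targets, weights) triple as seen by the metric layers.
log_probs = np.log(np.array([[0.7, 0.2, 0.1],
                             [0.1, 0.8, 0.1]]))
targets = np.array([0, 1])
weights = np.array([1.0, 1.0])

per_example = -log_probs[np.arange(len(targets)), targets]
xent = (per_example * weights).sum() / weights.sum()   # weighted cross entropy
neg_log_perplexity = -xent                             # Serial(..., Negate())
weights_per_batch = weights.sum()                      # Serial(Drop, Drop, Sum)
print(round(xent, 3), round(neg_log_perplexity, 3), weights_per_batch)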
Example #10
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train'):
    """Returns a Transformer model.

  This model expects an input pair: source tokens and target tokens.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    in_embed = [  # tokens
        tl.Embedding(d_model, input_vocab_size),  # vecs
        tl.Dropout(rate=dropout, mode=mode),  # vecs
        tl.PositionalEncoding(max_len=max_len),  # vecs
    ]

    if output_vocab_size is None:
        output_vocab_size = input_vocab_size
        out_embed = in_embed
    else:
        out_embed = [  # tokens
            tl.Embedding(d_model, output_vocab_size),  # vecs
            tl.Dropout(rate=dropout, mode=mode),  # vecs
            tl.PositionalEncoding(max_len=max_len),  # vecs
        ]

    encoder_stack = (  # masks vectors --> masks vectors
        [
            EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_encoder_layers)
        ])

    encoder_decoder_stack = (  # vecs_d masks vecs_e --> vecs_d masks vecs_e
        [
            EncoderDecoder(d_model, d_ff, n_heads, dropout, i, mode)
            for i in range(n_decoder_layers)
        ])

    # Input: encoder_side_tokens, decoder_side_tokens
    return tl.Serial(  # tokens_e tokens_d
        tl.Parallel([], tl.Dup()),  # toks_e toks_d toks_d (for loss)
        tl.Swap(),  # toks_d toks_e ....

        # Encode.
        tl.Parallel(  # toks_d        toks_e
            [],
            [
                tl.Dup(),  # ______ toks_e toks_e
                tl.Parallel(in_embed, tl.PaddingMask()),  # ______ vecs_e masks
                encoder_stack,  # ______ vecs_e masks
                tl.LayerNorm(),  # ______ vecs_e .....
                tl.Swap()
            ]),  # ______ masks  vecs_e

        # Decode.                                  #        toks_d masks vecs_e
        tl.ShiftRight(),  #        toks_d ..... ......
        out_embed,  #        vecs_d ..... ......
        tl.Dup(),  # vecs_d vecs_d ..... ......
        tl.Parallel([], tl.EncoderDecoderMask()),  # ______    masks     ......
        encoder_decoder_stack,  # vecs_d    masks     vecs_e
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vecs_d
        tl.LayerNorm(),  # vecs_d
        tl.Dense(output_vocab_size),  # vecs_d
        tl.LogSoftmax(),  # vecs_d
    )
Example #11
def Transformer(input_vocab_size,
                output_vocab_size=None,
                d_model=512,
                d_ff=2048,
                n_encoder_layers=6,
                n_decoder_layers=6,
                n_heads=8,
                dropout=0.1,
                max_len=2048,
                mode='train',
                ff_activation=tl.Relu):
    """Returns a Transformer model.

  This model expects an input pair: source tokens and target tokens.

  Args:
    input_vocab_size: int: vocab size of the source.
    output_vocab_size: int (optional): vocab size of the target. If None, the
      source and target are assumed to have the same vocab.
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_encoder_layers: int: number of encoder layers
    n_decoder_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'
    ff_activation: the non-linearity in feed-forward layer

  Returns:
    A Transformer model as a layer that maps from a target, source pair to
    activations over a vocab set.
  """
    def PositionalEmbedder(vocab_size):  # tokens --> vectors
        return [
            tl.Embedding(d_model, vocab_size),
            tl.Dropout(rate=dropout, mode=mode),
            tl.PositionalEncoding(max_len=max_len),
        ]

    def EncoderBlocks(n_blocks):  # vectors masks --> vectors masks
        return [
            _EncoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                          ff_activation) for i in range(n_blocks)
        ]

    def EncoderDecoderBlocks(n_blocks):  # vectors masks --> vectors masks
        return [
            _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, i, mode,
                                 ff_activation) for i in range(n_blocks)
        ]

    in_embed = PositionalEmbedder(input_vocab_size)
    out_embed = (in_embed if output_vocab_size is None else
                 PositionalEmbedder(output_vocab_size))
    if output_vocab_size is None:
        output_vocab_size = input_vocab_size

    # Assemble and return the model.
    return tl.Serial(
        # Input: encoder_side_tokens, decoder_side_tokens
        # Copy decoder tokens for use in loss.
        tl.Select([0, 1, 1]),  # tok_e tok_d tok_d

        # Encode.
        tl.Branch(in_embed, tl.PaddingMask()),  # vec_e masks ..... .....
        EncoderBlocks(n_encoder_layers),  # vec_e masks ..... .....
        tl.LayerNorm(),  # vec_e ..... ..... .....

        # Decode.
        tl.Select([2, 1, 0]),  # tok_d masks vec_e .....
        tl.ShiftRight(),  # tok_d ..... ..... .....
        out_embed,  # vec_d ..... ..... .....
        tl.Branch([], tl.EncoderDecoderMask()),  # vec_d masks ..... .....
        EncoderDecoderBlocks(n_decoder_layers),  # vec_d masks ..... .....
        tl.LayerNorm(),  # vec_d ..... ..... .....

        # Map to output vocab.
        tl.Parallel([], tl.Drop(), tl.Drop()),  # vec_d tok_d
        tl.Dense(output_vocab_size),  # vec_d .....
        tl.LogSoftmax(),  # vec_d .....
    )
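tl.Select([0, 1, 1]) at the top of this model is the stack-rewiring step that keeps a copy of the decoder tokens around for the loss, and tl.Select([2, 1, 0]) later reorders the stack for decoding. A toy sketch of Select's semantics on a plain Python list used as the stack (top of stack first; the real layer infers n_in as max(indices) + 1):

# Toy model of tl.Select: read items by index from the top of the stack,
# push them in the given order, and consume the items that were read.
def select(indices, stack):
    n_in = max(indices) + 1
    return [stack[i] for i in indices] + stack[n_in:]

print(select([0, 1, 1], ['tok_e', 'tok_d']))
# ['tok_e', 'tok_d', 'tok_d']  -- decoder tokens copied for the loss
print(select([2, 1, 0], ['vec_e', 'masks', 'tok_d', 'tok_d']))
# ['tok_d', 'masks', 'vec_e', 'tok_d']  -- reordered for the decoder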
Example #12
def ReformerShortenLM(vocab_size,
                      shorten_factor=1,
                      d_embedding=256,
                      d_model=512,
                      d_ff=2048,
                      d_attention_key=64,
                      d_attention_value=64,
                      n_layers=6,
                      n_heads=8,
                      dropout=0.1,
                      max_len=2048,
                      n_attention_chunks=1,
                      attention_type=tl.DotProductCausalAttention,
                      share_qk=False,
                      axial_pos_shape=(),
                      d_axial_pos_embs=None,
                      ff_activation=tl.FastGelu,
                      ff_use_sru=0,
                      mode='train'):
    """Reversible transformer language model with shortening.

  When shorten_factor is F and the input has shape [batch, length], we embed
  the (shifted-right) input and then group every F consecutive elements (along
  the length dimension) into a single vector, so that the main body of the
  model processes a tensor of shape
    [batch, length // F, d_model]
  almost until the end; at the end it is un-shortened and an SRU is applied.
  This reduces the length processed inside the main model body, making the
  model faster but possibly slightly less accurate.

  Args:
    vocab_size: int: vocab size
    shorten_factor: by how much to shorten, see above
    d_embedding: the depth of the embedding layer and final logits
    d_model: int:  depth of *each half* of the two-part features
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_layers: int: number of decoder layers
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    max_len: int: maximum symbol length for positional encoding
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    share_qk: bool, whether to share queries and keys.
    axial_pos_shape: tuple of ints: input shape to use for the axial position
      encoding. If unset, axial position encoding is disabled.
    d_axial_pos_embs: tuple of ints: depth of position embedding for each axis.
      Tuple length must match axial_pos_shape, values must sum to d_embedding.
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    if not axial_pos_shape:
        positional_encoding = tl.PositionalEncoding(max_len=max_len,
                                                    dropout=dropout)
    else:
        assert d_axial_pos_embs is not None
        positional_encoding = tl.AxialPositionalEncoding(
            shape=axial_pos_shape,
            d_embs=d_axial_pos_embs,
            dropout_broadcast_dims=tuple(range(1,
                                               len(axial_pos_shape) + 1)),
            dropout=dropout)

    positional_embedder = [
        tl.Embedding(d_embedding, vocab_size),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        positional_encoding,
    ]

    decoder_blocks = []

    if isinstance(attention_type, (tuple, list)):
        assert n_layers % len(attention_type) == 0
    else:
        attention_type = [attention_type]
    for layer_idx in range(n_layers):
        layer_attention_type = attention_type[layer_idx % len(attention_type)]
        decoder_block = DecoderBlock(
            d_model,
            d_ff,
            d_attention_key,
            d_attention_value,
            n_heads,
            n_attention_chunks,
            attention_type=layer_attention_type,
            dropout=dropout,
            share_qk=(share_qk or issubclass(layer_attention_type,
                                             tl.LSHCausalAttention)),
            ff_activation=ff_activation,
            ff_use_sru=ff_use_sru,
            mode=mode)
        decoder_blocks.append(decoder_block)

    # pylint: disable=g-long-lambda
    return tl.Serial(
        tl.ShiftRight(),
        positional_embedder,
        tl.Dup(),  # Stack has (x, x), the first will be shortened
        # Before shortening, we need to pad by shorten factor so as not to leak
        # information into the future. To understand why, imagine shorten factor
        # of 2 and sequence of length 4, so ABCD. If we shift just by 1, then we
        # would have 0ABC, which gets grouped to [0A][BC] on input, which is
        # predicting ABCD as targets. The problem is that [0A] has access to A
        # and [BC] has access to C -- it will learn to copy it, peek into
        # the future. Shifting twice to [00][AB] solves the problem as the first
        # "big" symbol becomes all-0 and the rest is shifted enough.
        tl.ShiftRight(n_shifts=shorten_factor - 1),
        tl.Fn(
            lambda x: np.reshape(  # Shorten -- move to depth.
                x, (x.shape[0], x.shape[1] // shorten_factor, -1)),
            n_out=1),
        tl.Dense(d_model),
        tl.Dup(),  # Stack has (short_x, short_x, x)
        tl.ReversibleSerial(decoder_blocks),
        tl.Parallel([], tl.Drop()),
        tl.LayerNorm(),
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
        tl.Dense(shorten_factor * d_embedding),
        tl.Fn(
            lambda x: np.reshape(  # Prolong back.
                x, (x.shape[0], x.shape[1] * shorten_factor, -1)),
            n_out=1),
        tl.Concatenate(),  # Concatenate with just the embeddings.
        tl.CausalConv(d_embedding),
        tl.Relu(),
        tl.SRU(d_embedding),  # One RNN layer for conditional dependence.
        tl.Dense(vocab_size),
        tl.LogSoftmax())
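The shortening itself is just the reshape inside the first tl.Fn above: group every shorten_factor consecutive position vectors into one wider vector, and undo it with the inverse reshape before the final SRU/Dense head. The extra ShiftRight(n_shifts=shorten_factor - 1) exists because, without it, a grouped chunk would contain symbols it is about to predict (the ABCD example in the comment). A numpy sketch of the two reshapes:

import numpy as np

batch, length, d_emb, factor = 1, 8, 4, 2
x = np.random.randn(batch, length, d_emb)

# Shorten: move `factor` consecutive positions into the depth dimension.
short = x.reshape(batch, length // factor, factor * d_emb)
print(short.shape)                       # (1, 4, 8)

# Prolong back: the inverse reshape applied near the end of the model.
prolonged = short.reshape(batch, length, d_emb)
assert np.array_equal(prolonged, x)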