Example #1
0
def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            d_model=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    """Builds an LSTM sequence-to-sequence model with attention.

    The model takes a pair (input tokens, target tokens), e.g. a tokenized
    English sentence together with its tokenized German translation.

    Args:
      input_vocab_size: int: vocab size of the input.
      target_vocab_size: int: vocab size of the target.
      d_model: int: depth of embedding (n_units in the LSTM cell).
      n_encoder_layers: int: number of LSTM layers in the encoder.
      n_decoder_layers: int: number of LSTM layers in the decoder after
        attention.
      n_attention_heads: int: number of attention heads.
      attention_dropout: float: dropout for the attention layer.
      mode: str: 'train', 'eval' or 'predict'; predict mode is for fast
        inference.

    Returns:
      An LSTM sequence-to-sequence model with attention.
    """
    # Encoder that turns the input tokens into encoder activations.
    encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)

    # Decoder layers that run on the targets before attention.
    pre_attn_decoder = pre_attention_decoder_fn(mode, target_vocab_size,
                                                d_model)

    # Keep a second copy of the input and target tokens for later use.
    duplicate_tokens = tl.Select([0, 1, 0, 1])

    # Run the encoder on the inputs and the pre-attention decoder on the
    # targets side by side.
    encode_both = tl.Parallel(encoder, pre_attn_decoder)

    # Turn the activations into queries, keys, values and a mask.
    to_attention_inputs = tl.Fn('PrepareAttentionInput',
                                prepare_attention_input,
                                n_out=4)

    # Attention is wrapped in a Residual so that its output is added to the
    # pre-attention decoder activations.
    attention_with_skip = tl.Residual(
        tl.AttentionQKV(d_model,
                        n_heads=n_attention_heads,
                        dropout=attention_dropout,
                        mode=mode))

    # Keep stack items 0 and 2, which discards the attention mask at index 1.
    discard_mask = tl.Select([0, 2])

    # Remaining LSTM layers of the decoder.
    decoder_rnn = [tl.LSTM(n_units=d_model) for _ in range(n_decoder_layers)]

    # Projection to the target vocabulary, then log-softmax.
    vocab_projection = tl.Dense(target_vocab_size)
    log_probs = tl.LogSoftmax()

    return tl.Serial(
        duplicate_tokens,
        encode_both,
        to_attention_inputs,
        attention_with_skip,
        discard_mask,
        decoder_rnn,
        vocab_projection,
        log_probs,
    )
Example #2
0
def NMTAttn(input_vocab_size=33300,
            target_vocab_size=33300,
            d_model=1024,
            n_encoder_layers=2,
            n_decoder_layers=2,
            n_attention_heads=4,
            attention_dropout=0.0,
            mode='train'):
    """Assembles a sequence-to-sequence model with attention as a `tl.Serial`.

    Args:
      input_vocab_size: passed to `input_encoder_fn` as the input vocab size.
      target_vocab_size: passed to `pre_attention_decoder_fn` and used as the
        width of the final dense layer.
      d_model: width passed to the encoder, pre-attention decoder, attention
        layer and decoder LSTMs.
      n_encoder_layers: passed to `input_encoder_fn`.
      n_decoder_layers: number of LSTM layers placed after attention.
      n_attention_heads: number of heads for `tl.AttentionQKV`.
      attention_dropout: dropout argument for `tl.AttentionQKV`.
      mode: passed to the pre-attention decoder and the attention layer.

    Returns:
      The assembled `tl.Serial` model, ending in a log-softmax.
    """
    encoder = input_encoder_fn(input_vocab_size, d_model, n_encoder_layers)
    pre_attn_decoder = pre_attention_decoder_fn(mode, target_vocab_size,
                                                d_model)

    pipeline = [
        # Duplicate the two inputs on the stack.
        tl.Select([0, 1, 0, 1]),
        # Encoder and pre-attention decoder run side by side.
        tl.Parallel(encoder, pre_attn_decoder),
        tl.Fn('PrepareAttentionInput', prepare_attention_input, n_out=4),
        # Residual adds the attention output to the pre-attention decoder
        # activations (i.e. the queries).
        tl.Residual(
            tl.AttentionQKV(d_model,
                            n_heads=n_attention_heads,
                            dropout=attention_dropout,
                            mode=mode)),
        # Keep stack items 0 and 2, dropping the attention mask at index 1.
        tl.Select([0, 2]),
        [tl.LSTM(d_model) for _ in range(n_decoder_layers)],
        tl.Dense(target_vocab_size),
        tl.LogSoftmax(),
    ]
    return tl.Serial(*pipeline)
Example #3
0
def _FunnelBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes, mode,
                 ff_activation, pool_layer, pool_size, strides, separate_cls):
    """Internal funnel block; returns the list of layers that implement it.

    The input is an activation tensor.

    Args:
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        block.
      n_heads: Number of attention heads.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      mode: If `'train'`, each block will include dropout; else, it will pass
        all values through unaltered.
      ff_activation: Type of activation function at the end of each block;
        must be an activation-type subclass of `Layer`.
      pool_layer: Type of pooling layer used for downsampling; should be
        `tl.AvgPool` or `tl.MaxPool`.
      pool_size: Shape of window that gets reduced to a single vector value.
        If the layer inputs are :math:`n`-dimensional arrays, then
        `pool_size` must be a tuple of length :math:`n-2`.
      strides: Offsets from the location of one window to the locations of
        neighboring windows along each axis. If specified, must be a tuple of
        the same length as `pool_size`. If None, then offsets of 1 along each
        window axis, :math:`(1, ..., 1)`, will be used.
      separate_cls: If `True`, pooling in funnel blocks is not applied to
        embeddings of the first token (`cls` from BERT paper).

    Returns:
      A list of layers that maps (activations, mask) to (activations', mask).
    """
    # Downsampling layers for the activations and for the mask.
    activation_pool = PoolLayer(pool_layer, pool_size, strides, separate_cls)
    mask_pool = MaskPool(pool_size, strides, separate_cls)

    qkv_attention = tl.AttentionQKV(d_model,
                                    n_heads=n_heads,
                                    dropout=dropout,
                                    mode=mode)
    post_attention_dropout = tl.Dropout(rate=dropout,
                                        shared_axes=dropout_shared_axes,
                                        mode=mode)

    ff_block = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,
                                 mode, ff_activation)

    # Stack contents are noted to the right of each layer; input: h, mask.
    normalize = tl.LayerNorm()  # h, mask
    add_pooled_copy = tl.Branch(activation_pool, None)  # h', h, mask
    attention_residual = tl.Residual(
        tl.Select([0, 1, 1, 2]),  # h', h, h, mask
        qkv_attention,  # attn, mask
        tl.Parallel(None, mask_pool),  # attn, mask'
        post_attention_dropout,  # attn, mask'
    )  # funnel_activations, mask'
    feed_forward_residual = tl.Residual(ff_block)

    return [
        normalize,
        add_pooled_copy,
        attention_residual,
        feed_forward_residual,
    ]
Example #4
0
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation):
    """Returns the layers of one Transformer encoder-decoder block.

    The input is a triple (decoder_activations, mask, encoder_activations)
    where the mask is created from the original input token IDs to prevent
    attending to the padding part of the encoder.

    Args:
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        block.
      n_heads: Number of attention heads.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      mode: If `'train'`, each block will include dropout; else, it will pass
        all values through unaltered.
      ff_activation: Type of activation function at the end of each block;
        must be an activation-type subclass of `Layer`.

    Returns:
      A list of layers which maps triples (decoder_activations, mask,
      encoder_activations) to triples of the same sort.
    """
    # Settings shared by the dropout layers created below.
    dropout_kwargs = dict(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    enc_dec_attention = tl.AttentionQKV(d_model,
                                        n_heads=n_heads,
                                        dropout=dropout,
                                        mode=mode,
                                        cache_KV_in_predict=True)

    self_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)

    ff_block = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,
                                 mode, ff_activation)

    # Stack layout on entry and exit of every sub-block: vec_d masks vec_e
    self_attention_residual = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        self_attention,  # vec_d ..... .....
        tl.Dropout(**dropout_kwargs),  # vec_d ..... .....
    )
    enc_dec_attention_residual = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        enc_dec_attention,  # vec_d masks vec_e
        tl.Dropout(**dropout_kwargs),  # vec_d masks vec_e
    )
    feed_forward_residual = tl.Residual(ff_block)  # vec_d masks vec_e

    return [
        self_attention_residual,
        enc_dec_attention_residual,
        feed_forward_residual,
    ]
Example #5
0
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, dropout_shared_axes,
                         mode, ff_activation):
    """Returns the layers of a Transformer encoder-decoder block.

    Every sub-block is wrapped in `ResidualZero`. The input is a triple
    (decoder_input, mask, encoder) where the mask is created from the
    original source to prevent attending to the padding part of the encoder.

    Args:
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      dropout_shared_axes: axes on which to share dropout mask.
      mode: str: 'train' or 'eval'.
      ff_activation: the non-linearity in feed-forward layer.

    Returns:
      A list of layers which maps triples (decoder_activations, mask,
      encoder_activations) to triples of the same sort.
    """
    # Settings shared by the dropout layers created below.
    dropout_kwargs = dict(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    enc_dec_attention = tl.AttentionQKV(d_model,
                                        n_heads=n_heads,
                                        dropout=dropout,
                                        mode=mode,
                                        cache_KV_in_predict=True)

    self_attention = tl.CausalAttention(d_model, n_heads=n_heads, mode=mode)

    ff_block = _FeedForwardBlock(d_model, d_ff, dropout, dropout_shared_axes,
                                 mode, ff_activation)

    # Stack layout on entry and exit of every sub-block: vec_d masks vec_e
    self_attention_residual = ResidualZero(
        tl.LayerNorm(),  # vec_d ..... .....
        self_attention,  # vec_d ..... .....
        tl.Dropout(**dropout_kwargs),  # vec_d ..... .....
    )
    enc_dec_attention_residual = ResidualZero(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        enc_dec_attention,  # vec_d masks vec_e
        tl.Dropout(**dropout_kwargs),  # vec_d masks vec_e
    )
    feed_forward_residual = ResidualZero(
        tl.LayerNorm(),
        ff_block,  # vec_d masks vec_e
        tl.Dropout(**dropout_kwargs),
    )

    return [
        self_attention_residual,
        enc_dec_attention_residual,
        feed_forward_residual,
    ]
Example #6
0
def _EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                         ff_activation):
    """Returns the layers of a Transformer encoder-decoder block.

    The input is a triple (decoder_input, mask, encoder) where the mask is
    created from the original source to prevent attending to the padding part
    of the encoder.

    Args:
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      layer_idx: which layer are we at (for bookkeeping).
      mode: str: 'train' or 'eval'.
      ff_activation: the non-linearity in feed-forward layer.

    Returns:
      A list of layers which maps triples (decoder_activations, mask,
      encoder_activations) to triples of the same sort.
    """
    # Keyword settings shared by both attention layers and by both dropouts.
    attention_kwargs = dict(n_heads=n_heads, dropout=dropout, mode=mode)
    dropout_kwargs = dict(rate=dropout, mode=mode)

    enc_dec_attention = tl.AttentionQKV(d_model, **attention_kwargs)
    self_attention = tl.BasicCausalAttention(d_model, **attention_kwargs)

    ff_block = _FeedForwardBlock(d_model, d_ff, dropout, layer_idx, mode,
                                 ff_activation)

    # Stack layout on entry and exit of every sub-block: vec_d masks vec_e
    self_attention_residual = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        self_attention,  # vec_d masks .....
        tl.Dropout(**dropout_kwargs),  # vec_d ..... .....
    )
    enc_dec_attention_residual = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        enc_dec_attention,  # vec_d masks vec_e
        tl.Dropout(**dropout_kwargs),  # vec_d masks vec_e
    )
    feed_forward_residual = tl.Residual(ff_block)  # vec_d masks vec_e

    return [
        self_attention_residual,
        enc_dec_attention_residual,
        feed_forward_residual,
    ]
def EncoderDecoder(d_model, d_ff, n_heads, dropout, layer_idx, mode,
                   ff_activation):
    """Transformer encoder-decoder layer.

    The input is a triple (decoder_input, mask, encoder) where the mask is
    created from the original source to prevent attending to the padding part
    of the encoder.

    Args:
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      layer_idx: which layer are we at (for bookkeeping).
      mode: str: 'train' or 'eval'.
      ff_activation: the non-linearity in feed-forward layer.

    Returns:
      The layer, returning a triple (decoder_activations, mask, encoder).
    """
    # Keyword settings shared by both attention layers and by both dropouts.
    attention_kwargs = dict(n_heads=n_heads, dropout=dropout, mode=mode)
    dropout_kwargs = dict(rate=dropout, mode=mode)

    # Stack on entry:                      vecs_d pmask vecs_e
    self_attention_layers = [
        tl.LayerNorm(),  # vecs_d ..... ......
        tl.BasicCausalAttention(d_model, **attention_kwargs),
        tl.Dropout(**dropout_kwargs),  # vecs_d ..... ......
    ]

    # Stack on entry:                      vecs_d masks vecs_e
    enc_dec_attention_layers = [
        tl.LayerNorm(),  # vecs_d masks vecs_e
        tl.Parallel([], [], tl.Dup()),  # ______ _____ vecs_e vecs_e
        tl.Parallel([], tl.Swap()),  # ______ vecs_e masks ......
        tl.Parallel([], tl.Dup()),  # ______ vecs_e vecs_e ..... ......
        # (q k v masks ... --> vecs_d masks ...)
        tl.AttentionQKV(d_model, **attention_kwargs),
        tl.Dropout(**dropout_kwargs),  # vecs_d mask vecs_e
    ]

    feed_forward_layers = [
        FeedForward(d_model, d_ff, dropout, layer_idx, mode, ff_activation),
    ]

    # Each group is wrapped in its own Residual; every sub-block maps
    # (vecs_d, masks, vecs_e) to (vecs_d, masks, vecs_e).
    residual_blocks = [
        tl.Residual(group) for group in (
            self_attention_layers,
            enc_dec_attention_layers,
            feed_forward_layers,
        )
    ]
    return tl.Serial(*residual_blocks)
Example #8
0
def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation, mode):
    """Reversible transformer decoder layer.

    Args:
      d_model: int: depth of embedding.
      d_ff: int: depth of feed-forward layer.
      n_heads: int: number of attention heads.
      dropout: float: dropout rate (how much to drop out).
      ff_activation: the non-linearity in feed-forward layer.
      mode: str: 'train' or 'eval'.

    Returns:
      A list of layers: three `ReversibleHalfResidual` sub-blocks (causal
      attention, encoder-decoder attention, feed-forward), each followed by a
      `tl.ReversibleSwap`.
    """
    # Encoder-decoder attention: norm + stack shuffle, attention, dropout.
    enc_dec_attention_stack = [
        [
            tl.LayerNorm(),
            tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        ],
        tl.AttentionQKV(d_model, n_heads=n_heads, dropout=dropout, mode=mode),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
    ]

    # Causal self-attention: norm, attention, dropout.
    causal_attention_stack = [
        tl.LayerNorm(),
        tl.CausalAttention(d_model, n_heads=n_heads, mode=mode),
        # TODO(kitaev): BroadcastedDropout?
        tl.Dropout(rate=dropout, mode=mode),
    ]

    ff_block = FeedForward(d_model, d_ff, dropout, ff_activation, mode)

    # Stack layout on entry: vec_d1 vec_d2 masks vec_e
    # TODO(kitaev): consider ReversibleAttentionHalfResidual for efficiency
    block_layers = []
    block_layers.append(ReversibleHalfResidual(causal_attention_stack))
    block_layers.append(tl.ReversibleSwap())
    block_layers.append(ReversibleHalfResidual(enc_dec_attention_stack))
    block_layers.append(tl.ReversibleSwap())
    block_layers.append(ReversibleHalfResidual(ff_block))
    block_layers.append(tl.ReversibleSwap())  # vec_d1 vec_d2 masks vec_e
    return block_layers
Example #9
0
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
  """Returns an LSTM sequence-to-sequence model with attention.

  This model is an encoder-decoder that performs tokenized string-to-string
  ("source"-to-"target") transduction:

    - inputs (2):

        - source: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(input_vocab_size)`, and `0`
          values mark padding positions.

        - target: rank 2 tensor representing a batch of text strings via token
          IDs plus padding markers; shape is (batch_size, sequence_length). The
          tensor elements are integers in `range(target_vocab_size)`, and `0`
          values mark padding positions.

    - output: rank 3 tensor representing a batch of log-probability
      distributions for each sequence position over possible token IDs;
      shape is (batch_size, sequence_length, `target_vocab_size`).

  An example use would be to translate (tokenized) sentences from English to
  German.

  The model works as follows:

  * Input encoder runs on the input tokens and creates activations that
    are used as both keys and values in attention.
  * Pre-attention decoder runs on the targets and creates
    activations that are used as queries in attention.
  * Attention runs on the queries, keys and values masking out input padding.
  * Decoder LSTMs run on the result, followed by a dense layer and
    log-softmax.

  Args:
    input_vocab_size: Input vocabulary size -- each element of the input tensor
        should be an integer in `range(input_vocab_size)`. These integers
        typically represent token IDs from a vocabulary-based tokenizer.
    target_vocab_size: Target vocabulary size.
    d_model: Final dimension of tensors at most points in the model, including
        the initial embedding output.
    n_encoder_layers: Number of LSTM layers in the encoder.
    n_decoder_layers: Number of LSTM layers in the decoder after attention.
    n_attention_heads: Number of attention heads.
    attention_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout within an attention block.
    mode: If `'predict'`, use fast inference. If `'train'`, each attention block
        will include dropout; else, it will pass all values through unaltered.

  Returns:
    An LSTM sequence-to-sequence model as a layer that maps from a
    source-target tokenized text pair to activations over a vocab set.
  """
  # Encoder: embedding followed by a stack of LSTMs.
  source_embedding = tl.Embedding(input_vocab_size, d_model)
  encoder_lstms = [tl.LSTM(d_model) for _ in range(n_encoder_layers)]
  input_encoder = tl.Serial(source_embedding, encoder_lstms)

  # Pre-attention decoder: shift targets right, embed, then a single LSTM.
  pre_attention_decoder = tl.Serial(
      tl.ShiftRight(mode=mode),
      tl.Embedding(target_vocab_size, d_model),
      tl.LSTM(d_model),
  )

  def _attention_inputs(encoder_activations, decoder_activations,
                        input_tokens):
    """Maps (encoder acts, decoder acts, input tokens) to (q, k, v, mask)."""
    # True at non-padding positions, False where the token ID is 0.
    not_padding = input_tokens != 0
    # Insert singleton axes for attention heads and decoder length.
    mask = jnp.reshape(
        not_padding, (not_padding.shape[0], 1, 1, not_padding.shape[1]))
    # Adding zeros broadcasts the mask to
    # [batch, 1 for heads, decoder-len, encoder-len].
    decoder_len = decoder_activations.shape[1]
    mask = mask + jnp.zeros((1, 1, decoder_len, 1))
    mask = mask.astype(jnp.float32)
    # Decoder activations are the queries; encoder activations serve as both
    # keys and values.
    return decoder_activations, encoder_activations, encoder_activations, mask

  # Stack contents are noted after each layer; input: in-toks, target-toks
  duplicate_tokens = tl.Select([0, 1, 0, 1])  # in, target, in, target
  encode_both = tl.Parallel(input_encoder, pre_attention_decoder)
  # q, k, v, mask, target-toks
  prepare_attention = tl.Fn('PrepareAttentionInputs', _attention_inputs,
                            n_out=4)
  # decoder-vecs, mask, target-toks
  attention_with_skip = tl.Residual(
      tl.AttentionQKV(d_model, n_heads=n_attention_heads,
                      dropout=attention_dropout, mode=mode)
  )
  discard_mask = tl.Select([0, 2])  # decoder-vecs, target-toks
  decoder_lstms = [tl.LSTM(d_model) for _ in range(n_decoder_layers)]
  vocab_projection = tl.Dense(target_vocab_size)
  log_probs = tl.LogSoftmax()

  return tl.Serial(
      duplicate_tokens,
      encode_both,
      prepare_attention,
      attention_with_skip,
      discard_mask,
      decoder_lstms,
      vocab_projection,
      log_probs,
  )
Example #10
0
def EncoderDecoderBlock(d_model,
                        d_ff,
                        n_heads,
                        dropout,
                        dropout_shared_axes,
                        mode,
                        ff_activation,
                        ff_dropout,
                        ff_chunk_size,
                        ff_use_sru,
                        ff_sparsity,
                        ff_sparsity_type,
                        attention_chunk_size,
                        attention_type,
                        enc_dec_attention_sparsity=0):
    """Returns the layers of a configurable Transformer encoder-decoder block.

    The input is a triple (decoder_activations, mask, encoder_activations)
    where the mask is created from the original input token IDs to prevent
    attending to the padding part of the encoder.

    Args:
      d_model: Final dimension of tensors at most points in the model,
        including the initial embedding output.
      d_ff: Size of special dense layer in the feed-forward part of each
        block.
      n_heads: Number of attention heads.
      dropout: Stochastic rate (probability) for dropping an activation value
        when applying dropout within a block.
      dropout_shared_axes: Tensor axes on which to share a dropout mask.
        Sharing along batch and sequence axes (`dropout_shared_axes=(0,1)`)
        is a useful way to save memory and apply consistent masks to
        activation vectors at different sequence positions.
      mode: If `'train'`, each block will include dropout; else, it will pass
        all values through unaltered.
      ff_activation: Type of activation function at the end of each block;
        must be an activation-type subclass of `Layer`.
      ff_dropout: Stochastic rate (probability) for dropping an activation
        value when applying dropout after the FF dense layer.
      ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks.
      ff_use_sru: int or pair of ints; if > 0, we use this many SRU layers in
        addition to the feed-forward block (second int specifies sru size).
      ff_sparsity: int, if > 0 use sparse feed-forward block with this
        sparsity.
      ff_sparsity_type: string, if ff_sparsity > 0, use SparseFF if
        ff_sparsity_type=`'1inN'` and use BlockSparseFF if
        ff_sparsity_type=`'Block'`.
      attention_chunk_size: int, if > 0 run attention chunked at this size.
      attention_type: The attention layer to use.
      enc_dec_attention_sparsity: Sparsity to use in encoder-decoder
        attention. Either a tuple `(q_sparsity, result_sparsity)`, a positive
        number used as `q_sparsity`, or a non-positive number for no
        sparsity.

    Returns:
      A list of layers which maps triples (decoder_activations, mask,
      encoder_activations) to triples of the same sort.
    """
    # Settings shared by the dropout layers created below.
    dropout_kwargs = dict(rate=dropout,
                          shared_axes=dropout_shared_axes,
                          mode=mode)

    # TODO(afrozm): This layer isn't configurable because: We currently don't
    # have any alternative for it (LSH cannot do it fundamentally, that's why
    # we have NoEncDec models, and local attention doesn't make sense in the
    # general setting where we don't know what in input is local to what in
    # output; some variants of FAVOR can do it, so maybe in the future, but we
    # don't have them yet).
    q_sparsity = result_sparsity = None  # default: no sparsity
    if isinstance(enc_dec_attention_sparsity, tuple):
        q_sparsity, result_sparsity = enc_dec_attention_sparsity
    elif enc_dec_attention_sparsity > 0:
        # 'noop' means we simply skip the Dense layer after attention.
        q_sparsity, result_sparsity = enc_dec_attention_sparsity, 'noop'

    enc_dec_attention = tl.AttentionQKV(d_model,
                                        n_heads=n_heads,
                                        dropout=dropout,
                                        mode=mode,
                                        cache_KV_in_predict=True,
                                        q_sparsity=q_sparsity,
                                        result_sparsity=result_sparsity)

    # The same per-head width is passed for both positional size arguments.
    d_per_head = d_model // n_heads
    self_attention = ApplyAttentionLayer(
        attention_type,
        d_model,
        n_heads,
        d_per_head,
        d_per_head,
        causal=True,
        masked=True,
        attention_dropout=dropout,
        output_dropout=dropout,
        attention_chunk_size=attention_chunk_size,
        mode=mode)

    ff_block = FeedForwardWithOptions(d_model, d_ff, dropout,
                                      dropout_shared_axes, ff_activation,
                                      ff_dropout, ff_chunk_size, ff_use_sru,
                                      ff_sparsity, mode, False,
                                      ff_sparsity_type)

    # Stack layout on entry and exit of every sub-block: vec_d masks vec_e
    self_attention_residual = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        self_attention,  # vec_d ..... .....
        tl.Dropout(**dropout_kwargs),  # vec_d ..... .....
    )
    enc_dec_attention_residual = tl.Residual(
        tl.LayerNorm(),  # vec_d ..... .....
        tl.Select([0, 2, 2, 1, 2]),  # vec_d vec_e vec_e masks vec_e
        enc_dec_attention,  # vec_d masks vec_e
        tl.Dropout(**dropout_kwargs),  # vec_d masks vec_e
    )
    feed_forward_residual = tl.Residual(ff_block)  # vec_d masks vec_e

    return [
        self_attention_residual,
        enc_dec_attention_residual,
        feed_forward_residual,
    ]
Example #11
0
 def _AttentionQKV():
     """Returns a new `tl.AttentionQKV` layer with K/V caching in predict mode.

     `d_model`, `n_heads`, `dropout` and `mode` are not parameters; they are
     resolved from the surrounding scope.
     """
     qkv_settings = dict(n_heads=n_heads,
                         dropout=dropout,
                         mode=mode,
                         cache_KV_in_predict=True)
     return tl.AttentionQKV(d_model, **qkv_settings)
Example #12
0
File: rnn.py Project: yangliuy/trax
def LSTMSeq2SeqAttn(input_vocab_size=256,
                    target_vocab_size=256,
                    d_model=512,
                    n_encoder_layers=2,
                    n_decoder_layers=2,
                    n_attention_heads=1,
                    attention_dropout=0.0,
                    mode='train'):
    """Returns an LSTM sequence-to-sequence model with attention.

    The input to the model is a pair (input tokens, target tokens), e.g.,
    an English sentence (tokenized) and its translation into German
    (tokenized).

    The model works as follows:

    * Input encoder runs on the input tokens and creates activations that
      are used as both keys and values in attention.
    * Pre-attention decoder runs on the targets and creates
      activations that are used as queries in attention.
    * Attention runs on the queries, keys and values masking out input
      padding.
    * Decoder LSTMs run on the result, followed by a dense layer and
      log-softmax.

    Args:
      input_vocab_size: int: vocab size of the input.
      target_vocab_size: int: vocab size of the target.
      d_model: int: depth of embedding (n_units in the LSTM cell).
      n_encoder_layers: int: number of LSTM layers in the encoder.
      n_decoder_layers: int: number of LSTM layers in the decoder after
        attention.
      n_attention_heads: int: number of attention heads.
      attention_dropout: float: dropout for the attention layer.
      mode: str: 'train', 'eval' or 'predict'; predict mode is for fast
        inference.

    Returns:
      An LSTM sequence-to-sequence model with attention.
    """
    # Encoder: embedding followed by a stack of LSTMs.
    source_embedding = tl.Embedding(input_vocab_size, d_model)
    encoder_lstms = [tl.LSTM(d_model) for _ in range(n_encoder_layers)]
    input_encoder = tl.Serial(source_embedding, encoder_lstms)

    # Pre-attention decoder: shift targets right, embed, then a single LSTM.
    pre_attention_decoder = tl.Serial(
        tl.ShiftRight(mode=mode),
        tl.Embedding(target_vocab_size, d_model),
        tl.LSTM(d_model),
    )

    def _attention_inputs(encoder_activations, decoder_activations,
                          input_tokens):
        """Maps (encoder acts, decoder acts, input tokens) to (q, k, v, mask)."""
        # True at non-padding positions, False where the token ID is 0.
        not_padding = input_tokens != 0
        # Insert singleton axes for attention heads and decoder length.
        mask = jnp.reshape(
            not_padding, (not_padding.shape[0], 1, 1, not_padding.shape[1]))
        # Adding zeros broadcasts the mask to
        # [batch, 1 for heads, decoder-len, encoder-len].
        decoder_len = decoder_activations.shape[1]
        mask = mask + jnp.zeros((1, 1, decoder_len, 1))
        # Decoder activations are the queries; encoder activations serve as
        # both keys and values.
        return (decoder_activations, encoder_activations,
                encoder_activations, mask)

    # Stack contents are noted after each layer; input: in-toks, target-toks
    duplicate_tokens = tl.Select([0, 1, 0, 1])  # in, target, in, target
    encode_both = tl.Parallel(input_encoder, pre_attention_decoder)
    # q, k, v, mask, target-toks
    prepare_attention = tl.Fn('PrepareAttentionInputs', _attention_inputs,
                              n_out=4)
    # decoder-vecs, mask, target-toks
    attention_with_skip = tl.Residual(
        tl.AttentionQKV(d_model,
                        n_heads=n_attention_heads,
                        dropout=attention_dropout,
                        mode=mode))
    discard_mask = tl.Select([0, 2])  # decoder-vecs, target-toks
    decoder_lstms = [tl.LSTM(d_model) for _ in range(n_decoder_layers)]
    vocab_projection = tl.Dense(target_vocab_size)
    log_probs = tl.LogSoftmax()

    return tl.Serial(
        duplicate_tokens,
        encode_both,
        prepare_attention,
        attention_with_skip,
        discard_mask,
        decoder_lstms,
        vocab_projection,
        log_probs,
    )