def ChunkedDecoderLayer(feature_depth,
                        feedforward_depth,
                        num_heads,
                        dropout,
                        chunk_selector,
                        mode):
  """Transformer decoder layer operating on chunks.

  Args:
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to the list of chunks to
      attend to.
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  return tl.Serial(
      tl.Residual(  # Self-attention block.
          tl.Map(tl.LayerNorm()),
          ChunkedCausalMultiHeadedAttention(
              feature_depth, num_heads=num_heads, dropout=dropout,
              chunk_selector=chunk_selector, mode=mode),
          tl.Map(tl.Dropout(rate=dropout, mode=mode)),
      ),
      tl.Map(ResidualFeedForward(
          feature_depth, feedforward_depth, dropout, mode=mode))
  )
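
A brief, hedged sketch of how this layer might be instantiated. The selector below follows the chunk-index-to-list-of-chunk-indices convention described in the docstrings; the helper name and hyperparameter values are illustrative only and not taken from the original code.

# Hypothetical helper: attend only to the immediately preceding chunk
# (the TransformerXL-style default described in ChunkedTransformerLM below).
def previous_chunk_selector(chunk_number):
  return [] if chunk_number < 1 else [chunk_number - 1]

# Hypothetical instantiation with illustrative hyperparameters.
decoder_layer = ChunkedDecoderLayer(
    feature_depth=512, feedforward_depth=2048, num_heads=8, dropout=0.1,
    chunk_selector=previous_chunk_selector, mode='train')
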
Example #2
def ChunkedDecoderLayer(d_feature, d_feedforward, n_heads, dropout,
                        chunk_selector, mode):
    """Transformer decoder layer operating on chunks.

  Args:
    d_feature: int:  depth of embedding
    d_feedforward: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to the list of chunks to
      attend to.
    mode: str: 'train' or 'eval'

  Returns:
    The layers comprising a chunked decoder.
  """
    return [
        tl.Residual(  # Self-attention block.
            tl.Map(tl.LayerNorm()),
            ChunkedCausalMultiHeadedAttention(d_feature,
                                              n_heads=n_heads,
                                              dropout=dropout,
                                              chunk_selector=chunk_selector,
                                              mode=mode),
            tl.Map(tl.Dropout(rate=dropout, mode=mode)),
        ),
        tl.Map(
            ResidualFeedForward(d_feature, d_feedforward, dropout, mode=mode))
    ]
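
Unlike the first variant, this version returns a flat list of layers rather than a tl.Serial, so the caller is expected to splice the list into a larger combinator. Below is a minimal sketch under that assumption; the layer count and widths are illustrative, and chunk_selector=None is assumed to fall back to previous-chunk attention as described in ChunkedTransformerLM below.

# Hypothetical stacking of the list-returning variant.
stack = []
for _ in range(6):  # illustrative number of decoder layers
  stack.extend(
      ChunkedDecoderLayer(d_feature=512, d_feedforward=2048, n_heads=8,
                          dropout=0.1, chunk_selector=None, mode='train'))
decoder_body = tl.Serial(*stack)
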
Example #3
def ChunkedTransformerLM(vocab_size,
                         feature_depth=512,
                         feedforward_depth=2048,
                         num_layers=6,
                         num_heads=8,
                         dropout=0.1,
                         chunk_selector=None,
                         max_len=2048,
                         mode='train'):
    """Transformer language model operating on chunks.

  The input to this model is a sequence presented as a list or tuple of
  chunks: (chunk1, chunk2, chunk3, ..., chunkN).
  Each chunk should have the same shape (batch, chunk-length) and together they
  represent a long sequence that is the concatenation of chunk1, ..., chunkN.

  The chunked Transformer emulates the operation of a Transformer on this long
  sequence, except that the chunked attention layer may attend to only a
  subset of the chunks in order to reduce memory use.

  Args:
    vocab_size: int: vocab size
    feature_depth: int:  depth of embedding
    feedforward_depth: int: depth of feed-forward layer
    num_layers: int: number of encoder/decoder layers
    num_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    chunk_selector: a function from chunk number to the list of chunks to
      attend to (if None, attends to the previous chunk, which is equivalent
      to setting chunk_selector(x) = [] if x < 1 else [x-1], as in
      TransformerXL; the current chunk is always attended to with a causal
      mask, while the selected chunks are unmasked).
    max_len: int: maximum symbol length for positional encoding
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
    stack = [
        ChunkedDecoderLayer(feature_depth, feedforward_depth, num_heads,
                            dropout, chunk_selector, mode)
        for _ in range(num_layers)
    ]
    # Below, each Map(L) applies the layer L to each chunk independently.
    return tl.Serial(
        tl.ShiftRight(),
        tl.Map(tl.Embedding(feature_depth, vocab_size)),
        tl.Map(tl.Dropout(rate=dropout, mode=mode)),
        tl.PositionalEncoding(max_len=max_len),
        tl.Serial(*stack),
        tl.Map(tl.LayerNorm()),
        tl.Map(tl.Dense(vocab_size)),
        tl.Map(tl.LogSoftmax()),
    )
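
A hedged usage sketch for the full model. The vocabulary size and chunking below are assumptions for illustration; the forward pass is shown only schematically, since parameter initialization and the training loop are outside the scope of these examples.

# Hypothetical construction of a chunked language model that attends to the
# previous chunk (the default when chunk_selector is None).
model = ChunkedTransformerLM(vocab_size=32000, num_layers=6, mode='train')
# The model consumes a tuple of equally shaped chunks, e.g. a batch of
# token sequences of length 1024 split into 4 chunks of 256 tokens each:
#   chunks = tuple(np.split(token_ids, 4, axis=1))   # each (batch, 256)
# and, once initialized, produces per-chunk log-probabilities over the vocab.
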
Example #4
def ChunkedCausalMultiHeadedAttention(d_feature,
                                      n_heads=8,
                                      dropout=0.0,
                                      chunk_selector=None,
                                      mode='train'):
    """Transformer-style causal multi-headed attention operating on chunks.

  Accepts inputs that are a list of chunks and applies causal attention.

  Args:
    d_feature: int:  depth of embedding
    n_heads: int: number of attention heads
    dropout: float: dropout rate
    chunk_selector: a function from chunk number to the list of chunks to
      attend to.
    mode: str: 'train' or 'eval'

  Returns:
    Multi-headed self-attention layer.
  """
    prepare_attention_input = tl.Serial(
        tl.Branch(
            tl.Branch(  # q = k = v = first input
                tl.NoOp(), tl.NoOp(), tl.NoOp()),
            tl.CausalMask(axis=-2),
        ),
        tl.Parallel(
            tl.Parallel(
                tl.Dense(d_feature),
                tl.Dense(d_feature),
                tl.Dense(d_feature),
            ), tl.NoOp()))
    return tl.Serial(
        tl.Map(prepare_attention_input),
        ChunkedAttentionSelector(selector=chunk_selector),  # pylint: disable=no-value-for-parameter
        tl.Map(tl.PureMultiHeadedAttention(d_feature=d_feature,
                                           n_heads=n_heads,
                                           dropout=dropout,
                                           mode=mode),
               check_shapes=False),
        tl.Map(tl.Select(0), check_shapes=False),  # drop masks
        tl.Map(tl.Dense(d_feature)))
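
Two hedged examples of alternative selectors that could be passed as chunk_selector; both follow the chunk-index-to-list-of-chunk-indices convention used above, and neither is part of the original library code.

def attend_to_all_previous_chunks(chunk_number):
  # Full causal attention over chunks: closest to an unchunked Transformer,
  # but with the highest memory use.
  return list(range(chunk_number))

def attend_to_last_k_chunks(chunk_number, k=2):
  # Sliding window over the last k chunks: memory per chunk stays bounded.
  return list(range(max(0, chunk_number - k), chunk_number))
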