Example #1
def AttentionPosition(vec,
                      pos,
                      positions=None,
                      d_model=None,
                      n_heads=8,
                      dropout=0.0,
                      mode='train'):
    """Transformer-style multi-headed attention."""

    # Produce n_heads combined position signals from the current (vec, pos).
    new_posns = list(
        LearnedPosOperations(positions=positions, n_combinations=n_heads)
        @ (vec, pos))

    # Build per-head queries (one new position signal per head), keys (the
    # original position tiled across heads) and values (activations only).
    hq = tl.Serial(tl.Dense(d_model),
                   CopyPosToHeads(n_heads, tile=False)) @ ([vec] + new_posns)
    hk = tl.Serial(tl.Dense(d_model),
                   CopyPosToHeads(n_heads, tile=True)) @ (vec, pos)
    hv = tl.ComputeAttentionHeads(n_heads=n_heads,
                                  d_head=d_model // n_heads) @ vec

    # Causal attention over the heads, then recombine heads and positions and
    # project back to d_model features.
    x, pos = tl.Serial(
        tl.DotProductCausalAttention(dropout=dropout, mode=mode),
        CombineHeadsPos(n_heads=n_heads),
        tl.Dense(d_model)) @ (hq, hk, hv)

    return x, pos
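
A minimal call sketch for the layer above; `vec`, `pos` and `positions` are assumed to be the symbolic activation stream, its position stream and the table of position vectors already in scope in the surrounding combinator-style model definition (they are not defined here), and the hyperparameter values are placeholders rather than settings from the original model:

# Not standalone: `vec`, `pos` and `positions` come from the enclosing model
# definition; d_model, n_heads and dropout below are illustrative only.
x, new_pos = AttentionPosition(
    vec, pos,
    positions=positions,
    d_model=256,
    n_heads=8,
    dropout=0.1,
    mode='train')
# `x` carries the attended activations (projected back to d_model features)
# and `new_pos` the recombined position signal from CombineHeadsPos.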
Example #2
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, n_attention_chunks, attention_type,
                 dropout, share_qk, ff_activation, ff_use_sru, ff_chunk_size,
                 mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    ff_activation: the non-linearity in feed-forward layer
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if not hasattr(attention_type, 'forward_unbatched'):
    if share_qk:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_value),
          ),
          tl.Dup(),
      ]
    else:
      pre_attention = [
          Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
          tl.LayerNorm(),
          tl.Dup(), tl.Dup(),
          tl.Parallel(
              tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
              tl.ComputeAttentionHeads(
                  n_heads=n_heads, d_head=d_attention_value),
          ),
      ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    attention_half_residual = ReversibleAttentionHalfResidual(
        pre_attention, attention, post_attention)
  else:
    attention = attention_type(
        n_heads=n_heads, d_qk=d_attention_key, d_v=d_attention_value,
        share_qk=share_qk, causal=True, output_dropout=dropout, mode=mode)
    attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=attention,
    )

  if ff_use_sru:
    feed_forward = [tl.SRU(d_model) for _ in range(ff_use_sru)]
  else:
    feed_forward = [ChunkedFeedForward(d_model, d_ff, dropout, ff_activation,
                                       dropout, ff_chunk_size, mode)]

  return [
      attention_half_residual,
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
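
A construction sketch for the block above; the hyperparameter values are placeholders, and tl.DotProductCausalAttention is chosen only because it is a tl.BaseCausalAttention subclass without `forward_unbatched`, so it exercises the first branch of the code:

import trax.layers as tl

# Illustrative values only, not settings from any published configuration.
block = DecoderBlock(
    d_model=512, d_ff=2048,
    d_attention_key=64, d_attention_value=64,
    n_heads=8, n_attention_chunks=1,
    attention_type=tl.DotProductCausalAttention,
    dropout=0.1, share_qk=False,
    ff_activation=tl.Relu,  # feed-forward non-linearity (layer constructor)
    ff_use_sru=0,           # 0: use the chunked feed-forward, not SRU layers
    ff_chunk_size=0,        # 0: do not chunk the feed-forward
    mode='train')
# `block` is a list of four layers: [attention half-residual, ReversibleSwap,
# feed-forward half-residual, ReversibleSwap], ready to splice into a
# reversible stack.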
Example #3
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value,
                 n_heads, n_attention_chunks, attention_type,
                 dropout, share_qk, mode):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: subclass of tl.BaseCausalAttention: attention class to use
    dropout: float: dropout rate (how much to drop out)
    share_qk: bool, whether to share queries and keys
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """
  if share_qk:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
        tl.Dup(),
    ]
  else:
    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(), tl.Dup(),
        tl.Parallel(
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_key),
            tl.ComputeAttentionHeads(n_heads=n_heads, d_head=d_attention_value),
        ),
    ]

  attention = attention_type(mode=mode)

  # ReversibleAttentionHalfResidual requires that post_attention be linear in
  # its input (so the backward pass can be computed without knowing the input)
  post_attention = [
      tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
      Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
      BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
  ]

  feed_forward = [
      FeedForward(d_model, d_ff, dropout, mode=mode),
  ]
  return [
      ReversibleAttentionHalfResidual(pre_attention, attention, post_attention),
      tl.ReversibleSwap(),
      ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
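
A sketch of how such blocks might be stacked into a decoder body, assuming DecoderBlock from Example #3 and the helper layers it uses are importable from the same module; the depth and hyperparameters are placeholders, and the Dup/Concatenate framing reflects the usual two-stream reversible setup rather than any specific published configuration:

import trax.layers as tl

# Build a small stack of reversible decoder blocks (illustrative depth).
blocks = []
for _ in range(6):
  blocks.extend(DecoderBlock(
      d_model=512, d_ff=2048,
      d_attention_key=64, d_attention_value=64,
      n_heads=8, n_attention_chunks=1,
      attention_type=tl.DotProductCausalAttention,
      dropout=0.1, share_qk=False, mode='train'))

# The reversible layers operate on a pair of residual streams, so the input is
# duplicated on the way in and the two halves are merged on the way out.
decoder_body = tl.Serial(
    tl.Dup(),
    tl.ReversibleSerial(blocks),
    tl.Concatenate(),
    tl.LayerNorm(),
)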