Example #1
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_type, dropout, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_type: class: attention class to use, such as DotProductAttention.
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """

    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
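        # The two Dup layers above make three copies of the normalized input;
        # the three Parallel branches below turn them into per-head queries,
        # keys and values respectively.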
        tl.Parallel(
            [
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key)
            ],
            [
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_key)
            ],
            [
                tl.ComputeAttentionHeads(n_heads=n_heads,
                                         d_head=d_attention_value)
            ],
        ),
    ]

    attention = attention_type(mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        tl.ComputeAttentionOutput(n_heads=n_heads, d_model=d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        BroadcastedDropout(rate=dropout, mode=mode),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [
        FeedForward(d_model, d_ff, dropout, mode=mode),
    ]
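    # The block alternates a reversible attention half-residual with a
    # reversible feed-forward half-residual; each ReversibleSwap exchanges the
    # two residual streams so both are updated across successive blocks.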
    return [
        ReversibleAttentionHalfResidual(pre_attention, attention,
                                        post_attention),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
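A minimal usage sketch for this example (the hyperparameter values and the
tl.DotProductCausalAttention attention class are illustrative assumptions, not
taken from the snippet above):

block = DecoderBlock(d_model=512, d_ff=2048,
                     d_attention_key=64, d_attention_value=64,
                     n_heads=8, n_attention_chunks=1,
                     attention_type=tl.DotProductCausalAttention,
                     dropout=0.1, mode='train')
# `block` is a list of four layers; a full decoder concatenates several such
# blocks (e.g. sum([DecoderBlock(...) for _ in range(n_layers)], [])) before
# wrapping them in a reversible serial combinator.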
Example #2
def DecoderBlock(d_model, d_ff, d_attention_key, d_attention_value, n_heads,
                 n_attention_chunks, attention_loop_stride, dropout, mode):
    """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    d_attention_key: int: depth of key vector for each attention head
    d_attention_value: int: depth of value vector for each attention head
    n_heads: int: number of attention heads
    n_attention_chunks: int: number of chunks for attention
    attention_loop_stride: int: number of query elements to compute attention
      for in parallel. Set to 0 to disable memory-efficient attention.
    dropout: float: dropout rate (how much to drop out)
    mode: str: 'train' or 'eval'

  Returns:
    the layer.
  """

    pre_attention = [
        Chunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(),
        tl.Dup(),
        tl.Dup(),
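        # The two Dup layers above make three copies of the normalized input;
        # each Parallel branch below applies a Dense projection and splits the
        # result into per-head queries, keys or values.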
        tl.Parallel(
            [tl.Dense(d_attention_key * n_heads),
             SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
            [tl.Dense(d_attention_key * n_heads),
             SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
            [tl.Dense(d_attention_value * n_heads),
             SplitHeads(n_heads=n_heads)],  # pylint: disable=no-value-for-parameter
        ),
    ]

    # TODO(kitaev): add dropout
    if attention_loop_stride < 1:
        # Use the standard implementation if no loop_stride is provided.
        attention = DotProductAttention(dropout=None, mode=mode)
    else:
        attention = MemoryEfficientDotProductAttention(
            loop_stride=attention_loop_stride, dropout=None, mode=mode)

    # ReversibleAttentionHalfResidual requires that post_attention be linear in
    # its input (so the backward pass can be computed without knowing the input)
    post_attention = [
        JoinHeads(),  # pylint: disable=no-value-for-parameter
        tl.Dense(d_model),
        Unchunk(n_sections=n_attention_chunks),  # pylint: disable=no-value-for-parameter
    ]

    feed_forward = [
        FeedForward(d_model, d_ff, dropout, mode=mode),
    ]
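    # Same structure as the variant above: attention and feed-forward
    # half-residuals, with ReversibleSwap exchanging the two residual streams.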
    return [
        ReversibleAttentionHalfResidual(pre_attention, attention,
                                        post_attention),
        tl.ReversibleSwap(),
        ReversibleHalfResidual(feed_forward),
        tl.ReversibleSwap(),
    ]
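A minimal usage sketch for this variant (the hyperparameter values are
illustrative assumptions):

block = DecoderBlock(d_model=512, d_ff=2048,
                     d_attention_key=64, d_attention_value=64,
                     n_heads=8, n_attention_chunks=1,
                     attention_loop_stride=128,  # attend to 128 queries per loop iteration
                     dropout=0.1, mode='train')
# Passing attention_loop_stride=0 instead selects the standard
# DotProductAttention path in the code above.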