Example 1
# Presumed imports for this snippet (it comes from Trax's Reformer code):
# `tl` is trax.layers, and `ct` is assumed to be
# trax.models.research.configurable_transformer, which provides
# FeedForwardWithOptions.
from trax import layers as tl
from trax.models.research import configurable_transformer as ct


def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode, ff_use_sru=0, ff_chunk_size=0,
                        ff_sparsity=0):
  """Reversible transformer decoder layer.

  Args:
    d_model: int:  depth of embedding
    d_ff: int: depth of feed-forward layer
    n_heads: int: number of attention heads
    dropout: float: dropout rate (how much to drop out)
    ff_activation: the non-linearity in feed-forward layer
    ff_dropout: float: (optional) separate dropout rate for feed-forward layer
    mode: str: 'train' or 'eval'
    ff_use_sru: int; if > 0, we use this many SRU layers instead of feed-forward
    ff_chunk_size: int; if > 0, chunk feed-forward into this-sized chunks
    ff_sparsity: int, if > 0 use sparse feed-forward block with this sparsity

  Returns:
    the layer.
  """
  enc_dec_attention = tl.EncDecAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  enc_dec_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=enc_dec_attention,
  )

  causal_attention = tl.SelfAttention(
      n_heads=n_heads, d_qk=d_model//n_heads, d_v=d_model//n_heads,
      causal=True,
      attention_dropout=dropout, output_dropout=dropout,
      mode=mode)
  causal_attention_half_residual = tl.ReversibleHalfResidual(
      tl.LayerNorm(),
      attention_layer=causal_attention,
  )

  feed_forward = ct.FeedForwardWithOptions(
      d_model, d_ff, dropout, [-2], ff_activation, ff_dropout,
      ff_chunk_size, ff_use_sru, ff_sparsity, mode)

  return [                             # vec_d1 vec_d2 vec_e masks
      causal_attention_half_residual,
      tl.ReversibleSwap(),
      enc_dec_attention_half_residual,
      tl.ReversibleSwap(),
      tl.ReversibleHalfResidual(feed_forward),
      tl.ReversibleSwap(),
  ]
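
A minimal usage sketch (not part of the original example; the hyperparameter
values are illustrative): each call to EncoderDecoderBlock returns a list of
reversible layers, and Trax's Reformer stacks several such lists inside
tl.ReversibleSerial.

# Illustrative hyperparameters only; d_model must be divisible by n_heads.
decoder_blocks = [
    EncoderDecoderBlock(d_model=512, d_ff=2048, n_heads=8, dropout=0.1,
                        ff_activation=tl.Relu, ff_dropout=0.1, mode='train')
    for _ in range(2)
]
# ReversibleSerial flattens the nested block lists into one reversible sequence.
decoder = tl.ReversibleSerial(decoder_blocks)
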
Example 2
# `tl` is trax.layers; `ReversibleHalfResidualV2` and `FeedForward` are
# helpers defined alongside this function in the (older) Trax reformer
# module, so they are not imported here.
from trax import layers as tl


def EncoderDecoderBlock(d_model, d_ff, n_heads, dropout, ff_activation,
                        ff_dropout, mode):
    """Reversible transformer decoder layer.

    Args:
      d_model: int:  depth of embedding
      d_ff: int: depth of feed-forward layer
      n_heads: int: number of attention heads
      dropout: float: dropout rate (how much to drop out)
      ff_activation: the non-linearity in feed-forward layer
      ff_dropout: float: (optional) separate dropout rate for feed-forward layer
      mode: str: 'train' or 'eval'

    Returns:
      the layer.
    """
    enc_dec_attention = tl.EncDecAttention(n_heads=n_heads,
                                           d_qk=d_model // n_heads,
                                           d_v=d_model // n_heads,
                                           attention_dropout=dropout,
                                           output_dropout=dropout,
                                           mode=mode)
    enc_dec_attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=enc_dec_attention,
    )

    causal_attention = tl.SelfAttention(n_heads=n_heads,
                                        d_qk=d_model // n_heads,
                                        d_v=d_model // n_heads,
                                        causal=True,
                                        attention_dropout=dropout,
                                        output_dropout=dropout,
                                        mode=mode)
    causal_attention_half_residual = ReversibleHalfResidualV2(
        tl.LayerNorm(),
        attention_layer=causal_attention,
    )

    feed_forward = FeedForward(d_model, d_ff, dropout, ff_activation,
                               ff_dropout, mode)

    return [  # vec_d1 vec_d2 vec_e masks
        causal_attention_half_residual,
        tl.ReversibleSwap(),
        enc_dec_attention_half_residual,
        tl.ReversibleSwap(),
        ReversibleHalfResidualV2(feed_forward),
        tl.ReversibleSwap(),
    ]
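
For reference, a conceptual plain-Python sketch (not Trax code; the function
names below are made up for illustration) of the RevNet-style pattern both
examples build: each ReversibleHalfResidual adds a function of one stream to
the other, and ReversibleSwap exchanges the two streams so successive
half-residuals update alternating halves, which is what lets activations be
recomputed during the backward pass instead of stored.

def half_residual_then_swap(x1, x2, f):
    y1 = x1 + f(x2)   # ReversibleHalfResidual: update the first stream
    return x2, y1     # ReversibleSwap: exchange streams for the next layer

def invert_half_residual_then_swap(z1, z2, f):
    x2, y1 = z1, z2   # undo the swap
    x1 = y1 - f(x2)   # recover the original stream without having stored it
    return x1, x2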