Example #1
File: rse.py Project: yliu45/trax
def ResidualSwitchUnit(
    d_model, dropout=0.1, mode='train', residual_weight=0.9):
  r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf.

  As defined in the paper:

  .. math::
    i &= [i_1, i_2] \\
    g &= GELU(LayerNorm(Z i)) \\
    c &= W g + B \\
    [o_1, o_2] &= \sigma(S) \odot i + h \odot c

  where Z, W, B, S are learnable parameters with sizes 2m × 4m, 4m × 2m, 2m, 2m.
  We assume that both i_1 and i_2 have size m. h is a scalar value.

  We assume the input is of shape [batch, length, depth].

  Args:
    d_model: output depth of the RSU layer
    dropout: dropout rate used in 'train' mode
    mode: mode for dropout layer
    residual_weight: value used in initializing vector S and constant h

  Returns:
    The RSU layer.
  """
  return tl.Serial(
      tl.Fn(
          'Reshape2Pairs',
          lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)),
          n_out=1),
      tl.Residual(
          tl.Dense(4 * d_model, use_bias=False),
          tl.LayerNorm(),
          tl.Gelu(),
          tl.Dense(2 * d_model),
          tl.Fn('Scaling',
                lambda x: x * np.sqrt(1 - residual_weight**2) * 0.25,
                n_out=1),
          shortcut=_ClippedScaling(residual_weight)),
      tl.Fn(
          'UnPair',
          lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)),
          n_out=1),
      tl.Dropout(rate=dropout, mode=mode)
      )
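A minimal usage sketch for the layer above (not part of the original file). It assumes the module-level names of rse.py are in scope, i.e. tl (trax.layers), np/jnp, and the private _ClippedScaling shortcut defined elsewhere in the same file; trax.shapes is used only to build the input signature.

from trax import shapes

rsu = ResidualSwitchUnit(d_model=8, dropout=0.1, mode='eval')
x = np.ones((2, 16, 8), dtype=np.float32)  # [batch, length, depth]; length must be even
rsu.init(shapes.signature(x))              # initializes the weights (Z, W, B, S in the docstring notation)
y = rsu(x)                                 # same shape as the input: (2, 16, 8)

Pairing halves the sequence length and doubles the depth inside the block, and UnPair restores the original [batch, length, depth] shape, so the layer is shape-preserving as long as the sequence length is even.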
Example #2
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval',
        ):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)
  embeddings = [
      tl.Select([0, 1, 0], n_in=3),  # Drop the 'idx' input; route token ids to both the embedding and the padding mask.
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model)
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]

  encoder += [tl.Select([0], n_in=2)]  # Drop the mask

  pooler = [
      # Pool by taking the first ([CLS]) token's vector; also pass the full sequence through.
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
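A minimal usage sketch for the model above (not part of the original file). It assumes the module context of bert.py, where tl, np, AddBias and PretrainedBERT are defined, and that PretrainedBERT with init_checkpoint=None behaves like a plain tl.Serial. ClassifierHead below is a hypothetical head written for illustration (the head argument only needs to be a no-argument callable returning a layer); the model itself takes three inputs, of which the third is dropped by the leading Select.

from trax import shapes

def ClassifierHead(n_classes=2):
  # Hypothetical head, not the library's own: keep the pooled [CLS] vector,
  # drop the per-token output, and map it to per-class log-probabilities.
  return tl.Serial(
      tl.Select([0], n_in=2),
      tl.Dense(n_classes),
      tl.LogSoftmax(),
  )

model = BERT(head=ClassifierHead, mode='eval')  # init_checkpoint is ignored outside 'train'

token_ids = np.ones((1, 128), dtype=np.int32)   # toy batch; real ids come from a WordPiece tokenizer
type_ids = np.zeros((1, 128), dtype=np.int32)
idx = np.zeros((1,), dtype=np.int32)            # third input, dropped by the leading Select
model.init(shapes.signature((token_ids, type_ids, idx)))
log_probs = model((token_ids, type_ids, idx))   # shape (1, n_classes)

Since BERT calls head() with no arguments, a head with different settings can be passed as, e.g., functools.partial(ClassifierHead, n_classes=5).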