def ResidualSwitchUnit(
    d_model, dropout=0.1, mode='train', residual_weight=0.9):
  r"""RSU (Residual Switch Unit) layer as in https://arxiv.org/pdf/2004.04662.pdf.

  As defined in the paper:

  .. math::
    i &= [i_1, i_2] \\
    g &= GELU(LayerNorm(Z i)) \\
    c &= W g + B \\
    [o_1, o_2] &= \sigma(S) \odot i + h \odot c

  where Z, W, B, S are learnable parameters with sizes 2m × 4m, 4m × 2m, 2m,
  2m. We assume that both i_1 and i_2 have size m. h is a scalar value.

  We assume the input is of shape [batch, length, depth].

  Args:
    d_model: output depth of the RSU layer
    dropout: dropout rate used in 'train' mode
    mode: mode for dropout layer
    residual_weight: value used in initializing vector S and constant h

  Returns:
    The RSU layer.
  """
  return tl.Serial(
      tl.Fn(
          'Reshape2Pairs',
          # Pair up adjacent positions: [batch, length, depth]
          # -> [batch, length // 2, 2 * depth].
          lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] // 2, -1)),
          n_out=1),
      tl.Residual(
          tl.Dense(4 * d_model, use_bias=False),
          tl.LayerNorm(),
          tl.Gelu(),
          tl.Dense(2 * d_model),
          tl.Fn(
              'Scaling',
              lambda x: x * np.sqrt(1 - residual_weight**2) * 0.25,
              n_out=1),
          shortcut=_ClippedScaling(residual_weight)),
      tl.Fn(
          'UnPair',
          # Undo the pairing: [batch, length // 2, 2 * depth]
          # -> [batch, length, depth].
          lambda x: jnp.reshape(x, (x.shape[0], x.shape[1] * 2, -1)),
          n_out=1),
      tl.Dropout(rate=dropout, mode=mode)
  )
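

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test for the RSU layer, assuming `trax` is installed and
# that `tl`, `np`, `jnp`, and `_ClippedScaling` are the module-level names
# defined elsewhere in this file. The demo name `_rsu_smoke_test` is
# hypothetical.
def _rsu_smoke_test():
  from trax import shapes

  rsu = ResidualSwitchUnit(d_model=8, mode='eval')
  # Input depth must equal d_model and length must be even, since the layer
  # pairs adjacent positions: [2, 16, 8] -> [2, 8, 16] -> ... -> [2, 16, 8].
  x = np.ones((2, 16, 8), dtype=np.float32)
  rsu.init(shapes.signature(x))
  y = rsu(x)
  assert y.shape == (2, 16, 8)  # Shape is preserved end to end.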
def BERT(d_model=768,
         vocab_size=30522,
         max_len=512,
         type_vocab_size=2,
         n_heads=12,
         d_ff=3072,
         n_layers=12,
         head=None,
         init_checkpoint=None,
         mode='eval'):
  """BERT (default hparams are for bert-base-uncased)."""
  layer_norm_eps = 1e-12
  d_head = d_model // n_heads

  word_embeddings = tl.Embedding(d_model, vocab_size)
  type_embeddings = tl.Embedding(d_model, type_vocab_size)
  position_embeddings = tl.PositionalEncoding(max_len, mode=mode)

  embeddings = [
      # Drop the 'idx' input; copy the token ids for the padding mask.
      tl.Select([0, 1, 0], n_in=3),
      tl.Parallel(
          word_embeddings,
          type_embeddings,
          [tl.PaddingMask(),
           tl.Fn('Squeeze', lambda x: np.squeeze(x, (1, 2)), n_out=1)]
      ),
      tl.Add(),
      position_embeddings,
      tl.LayerNorm(epsilon=layer_norm_eps),
  ]

  encoder = []
  for _ in range(n_layers):
    attn = tl.SelfAttention(n_heads=n_heads, d_qk=d_head, d_v=d_head,
                            bias=True, masked=True, mode=mode)
    feed_forward = [
        tl.Dense(d_ff),
        tl.Gelu(),
        tl.Dense(d_model)
    ]
    encoder += [
        tl.Select([0, 1, 1]),  # Save a copy of the mask.
        tl.Residual(attn, AddBias()),  # pylint: disable=no-value-for-parameter
        tl.LayerNorm(epsilon=layer_norm_eps),
        tl.Residual(*feed_forward),
        tl.LayerNorm(epsilon=layer_norm_eps),
    ]
  encoder += [tl.Select([0], n_in=2)]  # Drop the mask.

  pooler = [
      # Split off the [CLS] (first) token representation for pooling.
      tl.Fn('', lambda x: (x[:, 0, :], x), n_out=2),
      tl.Dense(d_model),
      tl.Tanh(),
  ]

  init_checkpoint = init_checkpoint if mode == 'train' else None
  bert = PretrainedBERT(
      embeddings + encoder + pooler, init_checkpoint=init_checkpoint)

  if head is not None:
    bert = tl.Serial(bert, head())

  return bert
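

# --- Usage sketch (illustrative, not part of the original module) ---
# Builds a small BERT and runs one forward pass. Assumes `PretrainedBERT`
# accepts init_checkpoint=None (random initialization) and that the three
# inputs are (token ids, segment/type ids, idx), where 'idx' is dropped by
# the first Select above. The demo name `_bert_smoke_test` is hypothetical.
def _bert_smoke_test():
  from trax import shapes

  model = BERT(d_model=64, n_heads=4, d_ff=256, n_layers=2, mode='eval')
  token_ids = np.ones((1, 32), dtype=np.int32)
  type_ids = np.zeros((1, 32), dtype=np.int32)
  idx = np.zeros((1,), dtype=np.int32)  # Unused; dropped inside the model.
  model.init(shapes.signature((token_ids, type_ids, idx)))
  pooled, activations = model((token_ids, type_ids, idx))
  assert pooled.shape == (1, 64)            # Tanh-pooled [CLS] vector.
  assert activations.shape == (1, 32, 64)   # Full encoder output sequence.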