def setup(self):
    # Embedding tables for tokens, positions, and segment (token-type) ids.
    self.word_embeddings = nn.Embed(
        num_embeddings=self.config.vocab_size,
        features=self.config.hidden_size,
        embedding_init=get_kernel_init(self.config),
        name="word_embeddings",
    )
    self.position_embeddings = layers.PositionalEncoding(
        num_embeddings=self.config.max_position_embeddings,
        features=self.config.hidden_size,
        embedding_init=get_kernel_init(self.config),
        name="position_embeddings",
    )
    self.type_embeddings = nn.Embed(
        num_embeddings=self.config.type_vocab_size,
        features=self.config.hidden_size,
        embedding_init=get_kernel_init(self.config),
        name="type_embeddings",
    )
    self.embeddings_layer_norm = nn.LayerNorm(
        epsilon=self.config.layer_norm_eps, name="embeddings_layer_norm")
    self.embeddings_dropout = nn.Dropout(
        rate=self.config.hidden_dropout_prob)

    # Factories for the sub-layers shared by every encoder block.
    build_feed_forward = functools.partial(
        layers.FeedForward,
        d_model=self.config.hidden_size,
        d_ff=self.config.intermediate_size,
        intermediate_activation=get_hidden_activation(self.config),
        kernel_init=get_kernel_init(self.config),
    )
    build_self_attention = functools.partial(
        layers.SelfAttention,
        num_heads=self.config.num_attention_heads,
        qkv_features=self.config.hidden_size,
        dropout_rate=self.config.attention_probs_dropout_prob,
        broadcast_dropout=False,
        kernel_init=get_kernel_init(self.config),
        bias_init=nn.initializers.zeros,
    )

    # Stack of Transformer encoder blocks.
    self.encoder_layers = [
        layers.TransformerBlock(
            build_feed_forward=build_feed_forward,
            build_self_attention=build_self_attention,
            dropout_rate=self.config.hidden_dropout_prob,
            layer_norm_epsilon=self.config.layer_norm_eps,
            name=f"encoder_layer_{layer_num}",
        )
        for layer_num in range(self.config.num_hidden_layers)
    ]

    # Dense layer applied to the first token to produce the pooled output.
    self.pooler = nn.Dense(
        kernel_init=get_kernel_init(self.config),
        name="pooler",
        features=self.config.hidden_size,
    )
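Under Linen, `setup` only registers the submodules; the forward pass itself lives in `__call__`, which is not shown in this section. A minimal sketch of what that `__call__` might look like, mirroring the forward logic of the pre-Linen `apply` below (the method signature, the `deterministic` flag, and the `TransformerBlock` call convention are assumptions for illustration, not the original code):

def __call__(self, input_ids, input_mask, type_ids, *, deterministic=False):
    # Hypothetical sketch: wiring mirrors the pre-Linen `apply` shown below.
    word_embeddings = self.word_embeddings(input_ids)
    position_embeddings = self.position_embeddings(word_embeddings)
    type_embeddings = self.type_embeddings(type_ids)

    embeddings = word_embeddings + position_embeddings + type_embeddings
    embeddings = self.embeddings_layer_norm(embeddings)
    embeddings = self.embeddings_dropout(embeddings, deterministic=deterministic)

    hidden_states = embeddings
    mask = input_mask.astype(jnp.int32)
    for encoder_layer in self.encoder_layers:
        hidden_states = encoder_layer(
            hidden_states, mask, deterministic=deterministic)

    # Pooled output from the hidden state of the first token.
    pooled_output = jnp.tanh(self.pooler(hidden_states[:, 0]))
    return hidden_states, pooled_output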
def apply(self,
          input_ids,
          input_mask,
          type_ids,
          *,
          config,
          deterministic=False):
    """Applies BERT model on the inputs."""
    word_embeddings = nn.Embed(
        input_ids,
        num_embeddings=config.vocab_size,
        features=config.hidden_size,
        embedding_init=get_kernel_init(config),
        name='word_embeddings')
    position_embeddings = layers.PositionalEncoding(
        word_embeddings,
        max_len=config.max_position_embeddings,
        posemb_init=get_kernel_init(config),
        name='position_embeddings')
    type_embeddings = nn.Embed(
        type_ids,
        num_embeddings=config.type_vocab_size,
        features=config.hidden_size,
        embedding_init=get_kernel_init(config),
        name='type_embeddings')

    # Sum the three embeddings, then normalize and apply dropout.
    embeddings = word_embeddings + position_embeddings + type_embeddings
    embeddings = nn.LayerNorm(
        embeddings, epsilon=LAYER_NORM_EPSILON, name='embeddings_layer_norm')
    embeddings = nn.dropout(
        embeddings,
        rate=config.hidden_dropout_prob,
        deterministic=deterministic)

    # Transformer blocks.
    feed_forward = layers.FeedForward.partial(
        d_ff=config.intermediate_size,
        dropout_rate=config.hidden_dropout_prob,
        intermediate_activation=get_hidden_activation(config),
        kernel_init=get_kernel_init(config))
    attention = efficient_attention.BertSelfAttention.partial(
        num_heads=config.num_attention_heads,
        num_parallel_heads=None,
        d_qkv=config.hidden_size // config.num_attention_heads,
        attention_dropout_rate=config.attention_probs_dropout_prob,
        output_dropout_rate=config.hidden_dropout_prob,
        kernel_init=get_kernel_init(config),
        output_kernel_init=get_kernel_init(config))

    hidden_states = embeddings
    mask = input_mask.astype(jnp.int32)
    for layer_num in range(config.num_hidden_layers):
        hidden_states = layers.TransformerBlock(
            hidden_states,
            mask,
            feed_forward=feed_forward,
            attention=attention,
            deterministic=deterministic,
            name=f'encoder_layer_{layer_num}')

    # Pool by passing the first token's hidden state through a Dense + tanh.
    pooled_output = nn.Dense(
        hidden_states[:, 0],
        config.hidden_size,
        kernel_init=get_kernel_init(config),
        name='pooler')
    pooled_output = jnp.tanh(pooled_output)
    return hidden_states, pooled_output
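For completeness, a minimal sketch of how the Linen module defined via `setup` above might be initialized and run. The class name `BertModel`, the call signature, and the input shapes are assumptions for illustration; `config` stands for the same configuration object used throughout this section (its construction is not shown here).

import jax
import jax.numpy as jnp

rng = jax.random.PRNGKey(0)
params_rng, dropout_rng = jax.random.split(rng)

# `BertModel` is a hypothetical name for the Linen module whose `setup` is
# shown above; `config` is the configuration object used in this section.
model = BertModel(config=config)

batch_size, seq_length = 2, 128  # illustrative shapes
input_ids = jnp.zeros((batch_size, seq_length), dtype=jnp.int32)
input_mask = jnp.ones((batch_size, seq_length), dtype=jnp.int32)
type_ids = jnp.zeros((batch_size, seq_length), dtype=jnp.int32)

# init needs a 'dropout' rng because dropout is active when deterministic=False.
variables = model.init(
    {'params': params_rng, 'dropout': dropout_rng},
    input_ids, input_mask, type_ids)

hidden_states, pooled_output = model.apply(
    variables, input_ids, input_mask, type_ids,
    deterministic=False, rngs={'dropout': dropout_rng})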