def test_feed_forward_layer(self):
    """Checks FeedForwardLayer's parameter structure and output shape."""
    batch_size, max_seq_length, hidden_dim = 3, 16, 12
    rng = jax.random.PRNGKey(0)

    layer_under_test = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.1)
    init_batch = {
        "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
    }
    params = init_layer_variables(rng, layer_under_test, init_batch)["params"]

    # The layer should expose exactly its two dense sub-modules.
    self.assertEqual(params.keys(), {"intermediate", "output"})

    rng, init_rng = jax.random.split(rng)
    inputs = jax.random.randint(
        init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=10)
    outputs = layer_under_test.apply({"params": params},
                                     rngs={"dropout": rng},
                                     inputs=inputs)

    # The feed-forward block must preserve the (batch, seq, hidden) shape.
    self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
def setup(self):
    """Initializes encoder with config-dependent mixing layer."""
    if self.config.model_arch == ModelArchitecture.F_NET:
        self._init_fourier_transform()

    # Random number generator key for RANDOM model architecture.
    key = random.PRNGKey(self.random_seed)

    # Module attributes are immutable once assigned, so the blocks are
    # accumulated in a temporary list and attached in one assignment.
    blocks = []
    for layer_idx in range(self.config.num_layers):
        key, mixing_key = random.split(key)
        if self._is_attention_layer(layer_idx):
            mixing_arch = ModelArchitecture.BERT
        else:
            mixing_arch = self.config.model_arch
        blocks.append(
            layers.EncoderBlock(
                mixing_sublayer=self._init_mixing_sublayer(
                    layer_idx, mixing_arch, mixing_key),
                feed_forward_sublayer=layers.FeedForwardLayer(
                    d_ff=self.config.d_ff,
                    dropout_rate=self.config.dropout_rate,
                    name=f"feed_forward_{layer_idx}"),
                name=f"encoder_{layer_idx}"))
    self.encoder_blocks = blocks

    self.embedder = layers.EmbeddingLayer(config=self.config, name="embedder")
    self.pooler = nn.Dense(
        self.config.d_model, kernel_init=default_kernel_init, name="pooler")
def test_encoder_block(self):
    """Checks EncoderBlock's parameter structure and output shape."""
    batch_size = 2
    max_seq_length = 14
    hidden_dim = 8
    rng = jax.random.PRNGKey(0)

    feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.0)
    mixing_layer = layers.IdentityTransform()
    encoder_block = layers.EncoderBlock(
        feed_forward_sublayer=feed_forward_layer,
        mixing_sublayer=mixing_layer)

    init_batch = {
        "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32),
        "padding_mask": jnp.ones((1, max_seq_length), jnp.int32)
    }
    params = init_layer_variables(rng, encoder_block, init_batch)["params"]

    expected_keys = {
        "mixing_layer_norm", "output_layer_norm", "feed_forward_sublayer"
    }
    self.assertEqual(params.keys(), expected_keys)

    # Use independent keys for the two random draws rather than reusing one.
    rng, inputs_rng, mask_rng = jax.random.split(rng, num=3)
    inputs = {
        "inputs":
            jax.random.randint(
                inputs_rng, (batch_size, max_seq_length, hidden_dim),
                minval=0,
                maxval=10),
        # BUG FIX: jax.random.randint's maxval is exclusive, so the original
        # maxval=1 always produced an all-zero (fully padded) mask; maxval=2
        # yields a genuine random 0/1 padding mask.
        "padding_mask":
            jax.random.randint(
                mask_rng, (batch_size, max_seq_length), minval=0, maxval=2)
    }
    outputs = encoder_block.apply({"params": params},
                                  rngs={"dropout": rng},
                                  **inputs)

    # The encoder block must preserve the (batch, seq, hidden) shape.
    self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))