def _init_mixing_sublayer(self, layer, model_arch, mixing_key):
  """Initializes config-dependent mixing sublayer."""
  if model_arch == ModelArchitecture.BERT:
    mixing_sublayer = nn.SelfAttention(
        num_heads=self.config.num_heads,
        qkv_features=self.config.d_model,
        broadcast_dropout=False,
        kernel_init=default_kernel_init,
        bias_init=default_bias_init,
        dropout_rate=self.config.mixing_dropout_rate,
        use_bias=True,
        name=f"self_attention_{layer}")
  elif model_arch == ModelArchitecture.F_NET:
    mixing_sublayer = layers.FourierTransform(
        fourier_transform=self.fourier_transform,
        name=f"fourier_transform_{layer}")
  elif model_arch == ModelArchitecture.FF_ONLY:
    mixing_sublayer = layers.IdentityTransform(
        name=f"identity_transform_{layer}")
  elif model_arch == ModelArchitecture.LINEAR:
    mixing_sublayer = layers.LinearTransform(
        precision=lax.Precision.DEFAULT, name=f"linear_transform_{layer}")
  elif model_arch == ModelArchitecture.RANDOM:
    mixing_sublayer = layers.RandomTransform(
        max_seq_length=self.config.max_seq_length,
        d_model=self.config.d_model,
        key=mixing_key,
        precision=lax.Precision.DEFAULT,
        name=f"random_transform_{layer}")
  else:
    raise ValueError("Unexpected model architecture: %s" % model_arch.name)

  return mixing_sublayer
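# A minimal sketch of the mixing the F_NET branch above wires in, assuming the
# standard FNet recipe: a parameter-free 2D DFT over the sequence and hidden
# dimensions, keeping only the real part. The class name here is hypothetical;
# the actual layers.FourierTransform receives its transform function
# (self.fourier_transform) from the model config rather than a default.
import functools
from typing import Callable

import jax.numpy as jnp
from flax import linen as nn


class FourierTransformSketch(nn.Module):
  """Parameter-free token mixing via a 2D FFT (illustrative stand-in)."""

  fourier_transform: Callable = functools.partial(jnp.fft.fftn, axes=(1, 2))

  @nn.compact
  def __call__(self, inputs):
    # inputs: [batch, seq_len, d_model]. The output has the same shape and
    # dtype as the (real-valued) input; there are no learnable parameters.
    return self.fourier_transform(inputs).real


# Example usage; an empty params dict suffices for a parameter-free module:
# FourierTransformSketch().apply({"params": {}}, jnp.ones((2, 14, 8)))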
def test_encoder_block(self):
  batch_size = 2
  max_seq_length = 14
  hidden_dim = 8
  rng = jax.random.PRNGKey(0)

  feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.0)
  mixing_layer = layers.IdentityTransform()
  encoder_block = layers.EncoderBlock(
      feed_forward_sublayer=feed_forward_layer, mixing_sublayer=mixing_layer)

  init_batch = {
      "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32),
      "padding_mask": jnp.ones((1, max_seq_length), jnp.int32)
  }
  params = init_layer_variables(rng, encoder_block, init_batch)["params"]

  expected_keys = {
      "mixing_layer_norm", "output_layer_norm", "feed_forward_sublayer"
  }
  self.assertEqual(params.keys(), expected_keys)

  rng, init_rng = jax.random.split(rng)
  inputs = {
      "inputs":
          jax.random.randint(
              init_rng, (batch_size, max_seq_length, hidden_dim),
              minval=0,
              maxval=10),
      # jax.random.randint excludes maxval, so use maxval=2 to sample a
      # random 0/1 padding mask (maxval=1 would yield an all-zero mask).
      "padding_mask":
          jax.random.randint(
              init_rng, (batch_size, max_seq_length), minval=0, maxval=2)
  }
  outputs = encoder_block.apply(
      {"params": params}, rngs={"dropout": rng}, **inputs)

  self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
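# Both tests rely on an init_layer_variables helper that is not shown in this
# section. A minimal sketch, assuming it simply wraps Module.init with
# separate "params" and "dropout" RNG streams (the real helper may differ):
import jax
from flax import linen as nn


def init_layer_variables(rng, module, init_batch):
  """Initializes module variables and returns the variable dict."""
  params_rng, dropout_rng = jax.random.split(rng)
  return module.init(
      {"params": params_rng, "dropout": dropout_rng}, **init_batch)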
def test_identity_transform(self):
  batch_size = 8
  max_seq_length = 8
  hidden_dim = 8
  rng = jax.random.PRNGKey(0)

  identity_layer = layers.IdentityTransform()
  init_batch = {
      "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
  }
  init_layer_variables(rng, identity_layer, init_batch)

  rng, init_rng = jax.random.split(rng)
  inputs = jax.random.randint(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=10)
  # IdentityTransform layer has no learnable params.
  outputs = identity_layer.apply({"params": {}}, inputs=inputs)

  # Inputs are unchanged by IdentityTransform layer.
  np.testing.assert_allclose(outputs, inputs)
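# For reference, a no-op mixing layer like the one exercised above can be a
# one-line Flax module. A minimal sketch under the class name below, which is
# hypothetical; the real layers.IdentityTransform (used by the FF_ONLY
# architecture) may differ in detail:
from flax import linen as nn


class IdentityTransformSketch(nn.Module):
  """Returns its inputs unchanged; has no learnable parameters."""

  @nn.compact
  def __call__(self, inputs):
    return inputs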