def test_hybrid_encoder(self, attention_layout, num_attention_layers,
                        expected_attention_layers):
  config = dummy_config(model_arch=ModelArchitecture.F_NET)
  with config.unlocked():
    config.num_layers = 4
    config.attention_layout = attention_layout
    config.num_attention_layers = num_attention_layers
  frozen_config = ml_collections.FrozenConfigDict(config)
  encoder = models.EncoderModel(config=frozen_config)

  rng = jax.random.PRNGKey(0)
  init_batch = init_encoder_batch(config)
  params = init_model_params(rng, encoder, init_batch)

  expected_keys = {
      "embedder", "encoder_0", "encoder_1", "encoder_2", "encoder_3",
      "feed_forward_0", "feed_forward_1", "feed_forward_2", "feed_forward_3",
      "pooler"
  }
  # Only the layers selected by the attention layout contribute
  # self-attention parameters; the remaining mixing sublayers are
  # parameter-free and therefore absent from params.
  for expected_attention_layer in expected_attention_layers:
    expected_keys.add(f"self_attention_{expected_attention_layer}")
  self.assertEqual(params.keys(), expected_keys)

  inputs = dummy_inputs(rng, config)
  hidden_states, pooled_output = encoder.apply({"params": params},
                                               rngs={"dropout": rng},
                                               **inputs)

  expected_hidden_states_shape = (config.train_batch_size,
                                  config.max_seq_length, config.d_model)
  self.assertEqual(hidden_states.shape, expected_hidden_states_shape)

  expected_pooled_output_shape = (config.train_batch_size, config.d_model)
  self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
def test_unparametrized_mixing_encoder(self, model_arch):
  config = dummy_config(model_arch=model_arch)
  frozen_config = ml_collections.FrozenConfigDict(config)
  encoder = models.EncoderModel(config=frozen_config)

  rng = jax.random.PRNGKey(0)
  init_batch = init_encoder_batch(config)
  params = init_model_params(rng, encoder, init_batch)

  # Unparameterized mixing encoders do not have any parameters in their
  # mixing layers, so their mixing layer names do not show up in params.
  expected_keys = {
      "embedder", "encoder_0", "encoder_1", "feed_forward_0",
      "feed_forward_1", "pooler"
  }
  self.assertEqual(params.keys(), expected_keys)

  inputs = dummy_inputs(rng, config)
  hidden_states, pooled_output = encoder.apply({"params": params},
                                               rngs={"dropout": rng},
                                               **inputs)

  expected_hidden_states_shape = (config.train_batch_size,
                                  config.max_seq_length, config.d_model)
  self.assertEqual(hidden_states.shape, expected_hidden_states_shape)

  expected_pooled_output_shape = (config.train_batch_size, config.d_model)
  self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
def test_f_net_encoder_bad_long_seq(self):
  # 8194 is not a power of 2, so the TPU-optimized Fourier path must reject
  # it at initialization.
  config = dummy_frozen_config(
      model_arch=base_config.ModelArchitecture.F_NET,
      use_tpu_fourier_optimizations=True,
      max_seq_length=8194)
  encoder = models.EncoderModel(config=config)

  rng = jax.random.PRNGKey(0)
  init_batch = init_encoder_batch(config)
  with self.assertRaisesRegex(
      ValueError,
      "must be a power of 2 to take advantage of FFT optimizations"):
    _ = init_model_params(rng, encoder, init_batch)
def test_bert_encoder(self):
  config = dummy_frozen_config(model_arch=base_config.ModelArchitecture.BERT)
  encoder = models.EncoderModel(config=config)

  rng = jax.random.PRNGKey(0)
  init_batch = init_encoder_batch(config)
  params = init_model_params(rng, encoder, init_batch)

  expected_keys = {
      "embedder", "encoder_0", "encoder_1", "feed_forward_0",
      "feed_forward_1", "self_attention_0", "self_attention_1", "pooler"
  }
  self.assertEqual(params.keys(), expected_keys)

  inputs = dummy_inputs(rng, config)
  hidden_states, pooled_output = encoder.apply({"params": params},
                                               rngs={"dropout": rng},
                                               **inputs)

  expected_hidden_states_shape = (config.train_batch_size,
                                  config.max_seq_length, config.d_model)
  self.assertEqual(hidden_states.shape, expected_hidden_states_shape)

  expected_pooled_output_shape = (config.train_batch_size, config.d_model)
  self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
def create_optimizer(key, config):
  """Creates optimizer for models.EncoderModel."""
  model = models.EncoderModel(config=config)

  init_batch = {
      "input_ids": jnp.ones((1, config.max_seq_length), jnp.int32),
      "input_mask": jnp.ones((1, config.max_seq_length), jnp.int32),
      "type_ids": jnp.ones((1, config.max_seq_length), jnp.int32)
  }

  key, dropout_key = jax.random.split(key)

  jit_init = jax.jit(model.init)
  initial_variables = jit_init({
      "params": key,
      "dropout": dropout_key
  }, **init_batch)
  params = initial_variables["params"]

  optimizer_def = optim.Adam(learning_rate=1e-4)
  return optimizer_def.create(params)
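# Usage sketch (illustrative, not part of the original tests): assuming the
# dummy_frozen_config helper from the encoder tests above, an optimizer for a
# small BERT config could be built like so:
#
#   config = dummy_frozen_config(model_arch=base_config.ModelArchitecture.BERT)
#   optimizer = create_optimizer(jax.random.PRNGKey(0), config)
#   params = optimizer.target  # flax.optim optimizers keep params in .target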
def test_f_net_encoder_short_seq(self):
  config = dummy_frozen_config(
      model_arch=base_config.ModelArchitecture.F_NET, max_seq_length=16)
  encoder = models.EncoderModel(config=config)

  rng = jax.random.PRNGKey(0)
  init_batch = init_encoder_batch(config)
  params = init_model_params(rng, encoder, init_batch)

  # Fourier sublayers have no parameters, so they do not show up in params.
  expected_keys = {
      "embedder", "encoder_0", "encoder_1", "feed_forward_0",
      "feed_forward_1", "pooler"
  }
  self.assertEqual(params.keys(), expected_keys)

  inputs = dummy_inputs(rng, config)
  hidden_states, pooled_output = encoder.apply({"params": params},
                                               rngs={"dropout": rng},
                                               **inputs)

  expected_hidden_states_shape = (config.train_batch_size,
                                  config.max_seq_length, config.d_model)
  self.assertEqual(hidden_states.shape, expected_hidden_states_shape)

  expected_pooled_output_shape = (config.train_batch_size, config.d_model)
  self.assertEqual(pooled_output.shape, expected_pooled_output_shape)