def _init_feed_forward_sublayer(self, layer):
  """Initializes config-dependent feed-forward sublayer."""
  if self._is_moe_layer(layer):
    expert = layers.FeedForwardLayer(
        d_ff=self.config.expert_d_ff,
        dropout_rate=self.config.expert_dropout_rate,
        dtype=layers.truncated_dtype(),
        name=f"expert_{layer}")
    ff_sublayer = layers.MoeLayer(
        num_experts=self.config.num_experts,
        router=self._init_routers(layer),
        max_group_size=self.config.max_group_size,
        train_capacity_factor=self.config.train_capacity_factor,
        eval_capacity_factor=self.config.eval_capacity_factor,
        expert=expert,
        min_expert_capacity=self.config.min_expert_capacity,
        dropout_rate=self.config.expert_dropout_rate,
        dtype=self.config.dtype,
        name=f"moe_{layer}")
  else:
    ff_sublayer = layers.FeedForwardLayer(
        d_ff=self.config.d_ff,
        dropout_rate=self.config.dropout_rate,
        dtype=self.config.dtype,
        name=f"feed_forward_{layer}")
  return ff_sublayer
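# A minimal sketch, under assumptions, of the layer-placement predicate that
# _init_feed_forward_sublayer relies on. `moe_layer_frequency` is a
# hypothetical config field (not from this codebase), used only to illustrate
# the predicate's contract; the real placement rule and _init_routers factory
# may differ.
def _is_moe_layer(self, layer: int) -> bool:
  """Illustrative only: every `moe_layer_frequency`-th block is sparse."""
  return (layer + 1) % self.config.moe_layer_frequency == 0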
def test_feed_forward_layer(self):
  batch_size = 3
  max_seq_length = 16
  hidden_dim = 12
  rng = jax.random.PRNGKey(0)

  feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.1)
  init_batch = {
      "input_emb": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
  }
  params = init_layer_variables(rng, feed_forward_layer, init_batch)["params"]

  expected_keys = {"intermediate", "output"}
  self.assertEqual(params.keys(), expected_keys)

  rng, init_rng = jax.random.split(rng)
  input_emb = jax.random.randint(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=10)
  outputs = feed_forward_layer.apply({"params": params},
                                     rngs={"dropout": rng},
                                     input_emb=input_emb)
  self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
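# The {"intermediate", "output"} parameter keys checked above imply the usual
# two-projection FFN: hidden_dim -> d_ff -> hidden_dim. A standalone Flax
# sketch of that shape contract (an illustration under assumptions, not the
# library's FeedForwardLayer; assumes `import flax.linen as nn`):
class ReferenceFeedForward(nn.Module):
  d_ff: int

  @nn.compact
  def __call__(self, x):
    # Expand to d_ff, apply a nonlinearity, then project back to hidden_dim.
    h = nn.relu(nn.Dense(self.d_ff, name="intermediate")(x))
    return nn.Dense(x.shape[-1], name="output")(h)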
def test_moe_layer_runs(self, dispatch):
  batch_size = 3
  max_seq_length = 4
  num_tokens = batch_size * max_seq_length
  hidden_dim = 2
  num_experts = 4
  rng = jax.random.PRNGKey(0)

  # Both router types share the same configuration; only the dispatch
  # mechanism (masked einsum vs. scatter indices) differs.
  router_cls = (routing.TokensChooseMaskedRouter if dispatch == "mask" else
                routing.TokensChooseScatterRouter)
  router = router_cls(
      router_weights=routing.RouterWeights(name="router_weights"),
      jitter_noise=0.,
      num_selected_experts=2,
      batch_prioritized_routing=True,
      dtype=jnp.float32)

  expert = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
  moe_layer = layers.MoeLayer(
      num_experts=num_experts,
      max_group_size=num_tokens,
      router=router,
      train_capacity_factor=1.5,
      eval_capacity_factor=1.5,
      expert=expert,
      axis_name="batch")
  init_batch = {
      "input_emb": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
  }
  params = init_layer_variables(rng, moe_layer, init_batch)["params"]

  expected_keys = {"router", "expert"}
  self.assertEqual(params.keys(), expected_keys)

  dropout_rng, jitter_rng, init_rng = jax.random.split(rng, num=3)
  input_emb = jax.random.uniform(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=-10,
      maxval=10)
  actual_outputs, state = moe_layer.apply({"params": params},
                                          rngs={
                                              "dropout": dropout_rng,
                                              "jitter": jitter_rng
                                          },
                                          mutable=["intermediates"],
                                          input_emb=input_emb)

  self.assertEqual(actual_outputs.shape,
                   (batch_size, max_seq_length, hidden_dim))
  self.assertIn("diversity_metrics", state["intermediates"])
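# For intuition about the capacity factors above: the standard Switch
# Transformer sizing gives each expert
#   capacity = capacity_factor * tokens_per_group / num_experts
# token slots (the library may round or clamp differently, e.g. via
# min_expert_capacity). With the 12 tokens per group, 4 experts, and factor
# 1.5 used in this test:
expert_capacity = int(1.5 * 12 / 4)  # 4 token slots per expert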
def test_num_groups(self, max_group_size, num_tokens, num_experts,
                    expected_num_groups):
  expert = layers.FeedForwardLayer(d_ff=2)
  router = routing.ExpertsChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      jitter_noise=0.,
      dtype=jnp.float32)
  moe_layer = layers.MoeLayer(
      num_experts=num_experts,
      router=router,
      max_group_size=max_group_size,
      train_capacity_factor=1.,
      eval_capacity_factor=1.,
      expert=expert)

  num_groups = moe_layer._num_groups(num_tokens, max_group_size)
  self.assertEqual(num_groups, expected_num_groups)
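# A reference sketch, under assumptions, of the grouping rule this test pins
# down: pack tokens into the fewest equal-sized groups whose size does not
# exceed max_group_size. The real _num_groups may add divisibility or
# sharding constraints beyond this.
def _reference_num_groups(num_tokens: int, max_group_size: int) -> int:
  num_groups = -(-num_tokens // max_group_size)  # ceil division
  while num_tokens % num_groups != 0:  # require equal-sized groups
    num_groups += 1
  return num_groups

assert _reference_num_groups(num_tokens=32, max_group_size=16) == 2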
def test_construct_encoder_block_correctly(self):
  max_seq_length = 14
  hidden_dim = 8
  rng = jax.random.PRNGKey(0)

  feed_forward_layer = layers.FeedForwardLayer(d_ff=8)
  mixing_layer = layers.LinearTransform()
  attention_layer = layers.AttentionLayer(num_heads=1, d_model=2)
  init_batch = {
      "input_emb": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32),
      "input_ids": jnp.ones((1, max_seq_length), jnp.int32),
  }

  # Success case: exactly one of the mixing/attention sublayers is provided.
  encoder_block = layers.EncoderBlock(
      feed_forward_sublayer=feed_forward_layer,
      mixing_sublayer=mixing_layer,
      attention_sublayer=None)
  _ = init_layer_variables(rng, encoder_block, init_batch)

  # Failure case: both sublayers are provided.
  with self.assertRaisesRegex(
      ValueError, "One, and only one, of {self.mixing_sublayer, "
      "self.attention_sublayer} must be nonempty"):
    encoder_block = layers.EncoderBlock(
        feed_forward_sublayer=feed_forward_layer,
        mixing_sublayer=mixing_layer,
        attention_sublayer=attention_layer)
    _ = init_layer_variables(rng, encoder_block, init_batch)

  # Failure case: neither sublayer is provided.
  with self.assertRaisesRegex(
      ValueError, "One, and only one, of {self.mixing_sublayer, "
      "self.attention_sublayer} must be nonempty"):
    encoder_block = layers.EncoderBlock(
        feed_forward_sublayer=feed_forward_layer,
        mixing_sublayer=None,
        attention_sublayer=None)
    _ = init_layer_variables(rng, encoder_block, init_batch)
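# The constructor contract exercised above is a one-of check; presumably
# something along these lines (a sketch of the contract, not the source):
def _validate_sublayers(mixing_sublayer, attention_sublayer):
  # Exactly one of the two must be set: both-None and both-set are rejected.
  if (mixing_sublayer is None) == (attention_sublayer is None):
    raise ValueError("One, and only one, of {self.mixing_sublayer, "
                     "self.attention_sublayer} must be nonempty.")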
def test_encoder_block_feed_forward(self):
  batch_size = 2
  max_seq_length = 14
  hidden_dim = 8
  rng = jax.random.PRNGKey(0)
  rng, init_rng = jax.random.split(rng)

  feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.0)
  mixing_layer = layers.LinearTransform()
  encoder_block = layers.EncoderBlock(
      feed_forward_sublayer=feed_forward_layer,
      mixing_sublayer=mixing_layer,
      attention_sublayer=None)

  input_emb = jax.random.uniform(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=10)
  input_ids = jax.random.randint(
      init_rng, (batch_size, max_seq_length), minval=0, maxval=20)
  params = init_layer_variables(rng, encoder_block, {
      "input_emb": input_emb,
      "input_ids": input_ids
  })["params"]

  expected_keys = {
      "mixing_sublayer", "mixing_layer_norm", "output_layer_norm",
      "feed_forward_sublayer"
  }
  self.assertEqual(params.keys(), expected_keys)

  outputs = encoder_block.apply({"params": params},
                                rngs={"dropout": rng},
                                input_emb=input_emb,
                                input_ids=input_ids)
  self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))
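# The four parameter keys asserted above match the usual post-norm residual
# arrangement (an inference from the key names, not a quote of the source):
#   x = mixing_layer_norm(input_emb + mixing_sublayer(input_emb))
#   out = output_layer_norm(x + feed_forward_sublayer(x))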
def test_scatter_mask_dispatch_equal(self):
  batch_size = 4
  max_seq_length = 4
  hidden_dim = 2
  num_experts = 2
  tokens_per_group = 8
  num_groups = batch_size * max_seq_length // tokens_per_group
  rng = jax.random.PRNGKey(0)

  expert = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
  moe_layer_factory = functools.partial(
      layers.MoeLayer,
      num_experts=num_experts,
      max_group_size=tokens_per_group,
      train_capacity_factor=1.,
      eval_capacity_factor=1.,
      expert=expert,
      split_params=False)  # Ensures all experts start with same params.

  router_weights = routing.RouterWeights(name="router_weights")
  masked_router = routing.TokensChooseMaskedRouter(
      router_weights=router_weights,
      jitter_noise=0.,
      num_selected_experts=2,
      batch_prioritized_routing=True,
      dtype=jnp.float32)
  masked_moe_layer = moe_layer_factory(router=masked_router)
  scatter_router = routing.TokensChooseScatterRouter(
      router_weights=router_weights,
      jitter_noise=0.,
      num_selected_experts=2,
      batch_prioritized_routing=True,
      dtype=jnp.float32)
  scatter_moe_layer = moe_layer_factory(router=scatter_router)

  input_emb = jax.random.uniform(
      rng, (batch_size, max_seq_length, hidden_dim), minval=-10, maxval=10)

  # Mock the router weights to ensure both layers compute with the same
  # logits.
  mock_router_logits = jax.random.uniform(
      rng, (num_groups, tokens_per_group, num_experts), minval=-1, maxval=1)
  with mock.patch.object(
      masked_router, "router_weights", return_value=mock_router_logits):
    masked_outputs, _ = masked_moe_layer.init_with_output(
        rng, input_emb, deterministic=True)
  with mock.patch.object(
      scatter_router, "router_weights", return_value=mock_router_logits):
    scatter_outputs, _ = scatter_moe_layer.init_with_output(
        rng, input_emb, deterministic=True)

  expected_outputs = jnp.array(
      [
          [
              [-8.16194050e-04, -3.92473085e-05],
              [-8.87976727e-04, 6.41788647e-05],
              [1.51725704e-04, 5.44631148e-05],
              [0.00000000e+00, 0.00000000e+00],
          ],
          [
              [-1.63517136e-03, 7.32473345e-05],
              [6.99331111e-04, -4.98824847e-05],
              [-7.68527039e-04, -1.00117592e-04],
              [3.73630854e-03, 1.74387533e-04],
          ],
          [
              [1.09393802e-03, 5.09395104e-05],
              [-4.27273808e-05, 1.12514383e-04],
              [3.19827022e-03, 1.41921133e-04],
              [2.31421960e-04, -2.57078882e-05],
          ],
          [
              [0.00000000e+00, 0.00000000e+00],
              [1.65408337e-03, 1.62946199e-05],
              [2.29193736e-03, 1.07774074e-04],
              [-9.18464328e-04, -4.17242954e-05],
          ],
      ],
      dtype=jnp.float32)

  np.testing.assert_allclose(
      masked_outputs, expected_outputs, rtol=1e-6, atol=1e-6)
  np.testing.assert_allclose(
      scatter_outputs, expected_outputs, rtol=1e-6, atol=1e-6)
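# Why identical router logits should force identical outputs: masked dispatch
# combines tokens through dense one-hot dispatch/combine arrays, while scatter
# dispatch gathers tokens by integer index; the two are the same selection
# algebraically. A toy illustration of that equivalence (names here are
# illustrative only, not the library's internals):
tokens = jnp.arange(6.).reshape(3, 2)              # [num_tokens, hidden_dim]
one_hot = jax.nn.one_hot(jnp.array([2, 0, 1]), 3)  # dense dispatch mask
np.testing.assert_allclose(one_hot @ tokens, tokens[jnp.array([2, 0, 1])])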
def test_encoder_block_switch(self):
  batch_size = 2
  max_seq_length = 14
  num_tokens = batch_size * max_seq_length
  hidden_dim = 8
  rng = jax.random.PRNGKey(0)
  init_rng, dropout_key, jitter_key = jax.random.split(rng, num=3)

  expert = layers.FeedForwardLayer(d_ff=4, dropout_rate=0.0, name="mlp")
  router = routing.TokensChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      jitter_noise=0.01,
      num_selected_experts=1,
      batch_prioritized_routing=True,
      dtype=jnp.float32)
  moe_layer = layers.MoeLayer(
      num_experts=2,
      router=router,
      max_group_size=num_tokens,
      train_capacity_factor=1.0,
      eval_capacity_factor=1.0,
      expert=expert)
  mixing_layer = layers.LinearTransform()
  encoder_block = layers.EncoderBlock(
      feed_forward_sublayer=moe_layer,
      mixing_sublayer=mixing_layer,
      attention_sublayer=None)

  input_emb = jax.random.uniform(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=10)
  input_ids = jax.random.randint(
      init_rng, (batch_size, max_seq_length), minval=0, maxval=20)
  params = init_layer_variables(rng, encoder_block, {
      "input_emb": input_emb,
      "input_ids": input_ids
  })["params"]

  expected_keys = {
      "mixing_sublayer", "mixing_layer_norm", "output_layer_norm",
      "feed_forward_sublayer"
  }
  self.assertEqual(params.keys(), expected_keys)

  outputs, state = encoder_block.apply({"params": params},
                                       rngs={
                                           "dropout": dropout_key,
                                           "jitter": jitter_key
                                       },
                                       mutable=["intermediates"],
                                       input_emb=input_emb,
                                       input_ids=input_ids)
  self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))

  self.assertIn("intermediates", state)
  jax.tree_util.tree_map(
      functools.partial(np.testing.assert_allclose, rtol=1e-5),
      state["intermediates"],
      FrozenDict({
          "feed_forward_sublayer": {
              "diversity_metrics":
                  layers.DiversityMetrics(
                      auxiliary_loss=0.9997709,
                      router_z_loss=0.48709542,
                      fraction_tokens_left_behind=0.03571427,
                      expert_usage=0.96428573,
                      router_confidence=0.51779515)
          }
      }))
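# For reference, the router_z_loss asserted above presumably follows the
# ST-MoE definition: the mean squared logsumexp of the router logits. A
# minimal JAX sketch, independent of this module's actual implementation:
def reference_router_z_loss(router_logits):
  # router_logits: [num_groups, tokens_per_group, num_experts].
  z = jax.nn.logsumexp(router_logits, axis=-1)
  return jnp.mean(jnp.square(z))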