def _init_routers(self, layer):
  """Returns the router specified by the config's dispatch algorithm."""
  if self.config.dispatch_algorithm == DispatchAlgorithm.MASK_TOKENS_CHOOSE:
    return routing.TokensChooseMaskedRouter(
        router_weights=routing.RouterWeights(name=f"router_weights_{layer}"),
        jitter_noise=self.config.jitter_noise,
        num_selected_experts=self.config.num_selected_experts,
        batch_prioritized_routing=self.config.batch_prioritized_routing,
        dtype=layers.truncated_dtype())
  elif (self.config.dispatch_algorithm ==
        DispatchAlgorithm.SCATTER_TOKENS_CHOOSE):
    return routing.TokensChooseScatterRouter(
        router_weights=routing.RouterWeights(name=f"router_weights_{layer}"),
        jitter_noise=self.config.jitter_noise,
        num_selected_experts=self.config.num_selected_experts,
        batch_prioritized_routing=self.config.batch_prioritized_routing,
        dtype=layers.truncated_dtype())
  elif (self.config.dispatch_algorithm ==
        DispatchAlgorithm.MASK_EXPERTS_CHOOSE):
    return routing.ExpertsChooseMaskedRouter(
        router_weights=routing.RouterWeights(name=f"router_weights_{layer}"),
        jitter_noise=self.config.jitter_noise,
        dtype=layers.truncated_dtype())
  else:
    raise ValueError(
        f"Unrecognized dispatch_algorithm: {self.config.dispatch_algorithm}")
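
# For illustration only (this mapping is ours, not part of the library): the
# if/elif chain above amounts to a lookup from dispatch algorithm to router
# class. Note that ExpertsChooseMaskedRouter takes no `num_selected_experts`
# or `batch_prioritized_routing` arguments, which is why its branch differs.
_ROUTER_CLASS_FOR_ALGORITHM = {
    DispatchAlgorithm.MASK_TOKENS_CHOOSE: routing.TokensChooseMaskedRouter,
    DispatchAlgorithm.SCATTER_TOKENS_CHOOSE: routing.TokensChooseScatterRouter,
    DispatchAlgorithm.MASK_EXPERTS_CHOOSE: routing.ExpertsChooseMaskedRouter,
}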
def test_moe_layer_runs(self, dispatch):
  batch_size = 3
  max_seq_length = 4
  num_tokens = batch_size * max_seq_length
  hidden_dim = 2
  num_experts = 4
  rng = jax.random.PRNGKey(0)

  if dispatch == "mask":
    router = routing.TokensChooseMaskedRouter(
        router_weights=routing.RouterWeights(name="router_weights"),
        jitter_noise=0.,
        num_selected_experts=2,
        batch_prioritized_routing=True,
        dtype=jnp.float32)
  else:
    router = routing.TokensChooseScatterRouter(
        router_weights=routing.RouterWeights(name="router_weights"),
        jitter_noise=0.,
        num_selected_experts=2,
        batch_prioritized_routing=True,
        dtype=jnp.float32)

  expert = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
  moe_layer = layers.MoeLayer(
      num_experts=num_experts,
      max_group_size=num_tokens,
      router=router,
      train_capacity_factor=1.5,
      eval_capacity_factor=1.5,
      expert=expert,
      axis_name="batch")

  init_batch = {
      "input_emb": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
  }
  params = init_layer_variables(rng, moe_layer, init_batch)["params"]

  expected_keys = {"router", "expert"}
  self.assertEqual(params.keys(), expected_keys)

  dropout_rng, jitter_rng, init_rng = jax.random.split(rng, num=3)
  input_emb = jax.random.uniform(
      init_rng, (batch_size, max_seq_length, hidden_dim),
      minval=-10,
      maxval=10)
  actual_outputs, state = moe_layer.apply(
      {"params": params},
      rngs={"dropout": dropout_rng, "jitter": jitter_rng},
      mutable=["intermediates"],
      input_emb=input_emb)

  self.assertEqual(actual_outputs.shape,
                   (batch_size, max_seq_length, hidden_dim))
  self.assertIn("diversity_metrics", state["intermediates"])
def test_experts_choose_mask_router(self):
  num_groups = 2
  tokens_per_group = 4
  hidden_dim = 3
  num_experts = 2
  expert_capacity = 2

  rng = jax.random.PRNGKey(0)
  token_inputs = jax.random.uniform(
      rng, (num_groups, tokens_per_group, hidden_dim), minval=0, maxval=1)

  router_mask, _ = routing.ExpertsChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      jitter_noise=0.,
      dtype=jnp.float32).init_with_output(jax.random.PRNGKey(0), token_inputs,
                                          num_experts, expert_capacity)

  expected_mask = jnp.array(
      [
          [
              [[0, 1], [1, 0]],
              [[0, 0], [0, 0]],
              [[0, 0], [0, 1]],
              [[1, 0], [0, 0]],
          ],
          [
              [[0, 0], [0, 1]],
              [[0, 0], [1, 0]],
              [[0, 1], [0, 0]],
              [[1, 0], [0, 0]],
          ],
      ],
      dtype=jnp.int32)
  np.testing.assert_allclose(router_mask.dispatch_mask, expected_mask)

  expected_weights = jnp.array(
      [
          [
              [[0., 0.50390625], [0.49804688, 0.]],
              [[0., 0.], [0., 0.]],
              [[0., 0.], [0., 0.49804688]],
              [[0.5078125, 0.], [0., 0.]],
          ],
          [
              [[0., 0.], [0., 0.49414062]],
              [[0., 0.], [0.49609375, 0.]],
              [[0., 0.5078125], [0., 0.]],
              [[0.51171875, 0.], [0., 0.]],
          ],
      ],
      dtype=jnp.float32)
  np.testing.assert_allclose(router_mask.combine_array, expected_weights)

  # Auxiliary loss is always 0 for experts-choose-tokens routing.
  self.assertEqual(router_mask.auxiliary_loss, 0.)
  self.assertEqual(router_mask.router_z_loss, 0.48657227)
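
# A minimal sketch (ours, not the library's implementation) of the
# experts-choose selection rule that the expectations above encode: each
# expert independently takes its `expert_capacity` highest-probability tokens,
# so dispatch is balanced by construction and no auxiliary load-balancing loss
# is needed.
def _experts_choose_dispatch_sketch(router_probs, expert_capacity):
  """Returns a [groups, tokens, experts, capacity] dispatch mask."""
  num_tokens = router_probs.shape[1]
  # Each expert ranks tokens: top_k over the token axis of the transposed
  # [groups, experts, tokens] array yields the token indices each expert picks.
  _, expert_index = jax.lax.top_k(
      jnp.swapaxes(router_probs, 1, 2), k=expert_capacity)
  # One-hot over tokens: [groups, experts, capacity, tokens].
  dispatch = jax.nn.one_hot(expert_index, num_tokens)
  # Rearrange to match the router's [groups, tokens, experts, capacity] layout.
  return jnp.transpose(dispatch, (0, 3, 1, 2))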
def test_tokens_choose_one_expert_mask_router_no_bpr(self):
  num_groups = 2
  tokens_per_group = 3
  hidden_dim = 4
  num_experts = 2
  num_selected_experts = 1  # Switch routing case
  # Total capacity = num_groups * num_experts * expert_capacity = 4 < 6 tokens
  expert_capacity = 1

  rng = jax.random.PRNGKey(0)
  token_inputs = jax.random.uniform(
      rng, (num_groups, tokens_per_group, hidden_dim), minval=0, maxval=1)

  router_mask, _ = routing.TokensChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      num_selected_experts=num_selected_experts,
      jitter_noise=0.,
      batch_prioritized_routing=False,
      dtype=jnp.float32).init_with_output(jax.random.PRNGKey(0), token_inputs,
                                          num_experts, expert_capacity)

  expected_mask = jnp.array(
      [
          [
              [[False], [True]],
              [[False], [False]],
              [[False], [False]],
          ],
          [
              [[False], [True]],
              [[True], [False]],
              [[False], [False]],
          ],
      ],
      dtype=jnp.bool_)
  np.testing.assert_allclose(router_mask.dispatch_mask, expected_mask)

  expected_weights = jnp.array(
      [
          [
              [[0.], [0.5078125]],
              [[0.], [0.]],
              [[0.], [0.]],
          ],
          [
              [[0.], [0.50390625]],
              [[0.5], [0.]],
              [[0.], [0.]],
          ],
      ],
      dtype=jnp.float32)
  np.testing.assert_allclose(router_mask.combine_array, expected_weights)

  self.assertEqual(router_mask.auxiliary_loss, 1.0065105)
  self.assertEqual(router_mask.router_z_loss, 0.4716797)
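
# For reference, a minimal sketch (the helper is ours, not part of the
# library) of the usual Switch Transformer-style capacity rule that makes the
# hand-picked `expert_capacity` above meaningful: each expert can hold at most
# capacity_factor * tokens_per_group / num_experts tokens from a group. With
# top-1 routing, a capacity below tokens_per_group / num_experts means total
# slots < tokens, so some tokens are necessarily dropped, as this test checks.
def _expert_capacity_sketch(tokens_per_group, num_experts, capacity_factor):
  """Max tokens each expert may receive from a single group."""
  return int(capacity_factor * tokens_per_group / num_experts)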
def test_tokens_choose_one_expert_scatter_router_no_bpr(self):
  num_groups = 2
  tokens_per_group = 4
  hidden_dim = 4
  expert_capacity = 2
  num_experts = 4
  num_selected_experts = 1  # Switch routing case

  rng = jax.random.PRNGKey(0)
  token_inputs = jax.random.uniform(
      rng, (num_groups, tokens_per_group, hidden_dim), minval=0, maxval=1)

  router_indices, _ = routing.TokensChooseScatterRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      num_selected_experts=num_selected_experts,
      jitter_noise=0.01,
      batch_prioritized_routing=False,
      dtype=jnp.float32).init_with_output(
          {
              "params": jax.random.PRNGKey(0),
              "jitter": jax.random.PRNGKey(0)
          }, token_inputs, num_experts, expert_capacity)

  expected_indices = jnp.array(
      [
          [
              [[1, 0]],
              [[0, 0]],
              [[1, 1]],
              [[1, 2]],
          ],
          [
              [[0, 0]],
              [[1, 0]],
              [[1, 1]],
              [[1, 2]],
          ],
      ],
      dtype=jnp.int32)
  np.testing.assert_allclose(router_indices.dispatch_indices,
                             expected_indices)

  expected_weights = jnp.array(
      [
          [[0.2578125], [0.25390625], [0.25585938], [0.]],
          [[0.2578125], [0.25390625], [0.25390625], [0.]],
      ],
      dtype=jnp.float32)
  np.testing.assert_allclose(router_indices.combine_weights, expected_weights)

  self.assertEqual(router_indices.auxiliary_loss, 1.0168457)
  self.assertEqual(router_indices.router_z_loss, 1.9111328)
def test_num_groups(self, max_group_size, num_tokens, num_experts,
                    expected_num_groups):
  expert = layers.FeedForwardLayer(d_ff=2)
  router = routing.ExpertsChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      jitter_noise=0.,
      dtype=jnp.float32)
  moe_layer = layers.MoeLayer(
      num_experts=num_experts,
      router=router,
      max_group_size=max_group_size,
      train_capacity_factor=1.,
      eval_capacity_factor=1.,
      expert=expert)

  num_groups = moe_layer._num_groups(num_tokens, max_group_size)
  self.assertEqual(num_groups, expected_num_groups)
def test_scatter_mask_dispatch_equal(self):
  batch_size = 4
  max_seq_length = 4
  hidden_dim = 2
  num_experts = 2
  tokens_per_group = 8
  num_groups = batch_size * max_seq_length // tokens_per_group
  rng = jax.random.PRNGKey(0)

  expert = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
  moe_layer_factory = functools.partial(
      layers.MoeLayer,
      num_experts=num_experts,
      max_group_size=tokens_per_group,
      train_capacity_factor=1.,
      eval_capacity_factor=1.,
      expert=expert,
      split_params=False)  # Ensures all experts start with same params

  router_weights = routing.RouterWeights(name="router_weights")
  masked_router = routing.TokensChooseMaskedRouter(
      router_weights=router_weights,
      jitter_noise=0.,
      num_selected_experts=2,
      batch_prioritized_routing=True,
      dtype=jnp.float32)
  masked_moe_layer = moe_layer_factory(router=masked_router)
  scatter_router = routing.TokensChooseScatterRouter(
      router_weights=router_weights,
      jitter_noise=0.,
      num_selected_experts=2,
      batch_prioritized_routing=True,
      dtype=jnp.float32)
  scatter_moe_layer = moe_layer_factory(router=scatter_router)

  input_emb = jax.random.uniform(
      rng, (batch_size, max_seq_length, hidden_dim), minval=-10, maxval=10)

  # Mock the router weights to ensure both layers compute with the same
  # logits.
  mock_router_logits = jax.random.uniform(
      rng, (num_groups, tokens_per_group, num_experts), minval=-1, maxval=1)
  with mock.patch.object(
      masked_router, "router_weights", return_value=mock_router_logits):
    masked_outputs, _ = masked_moe_layer.init_with_output(
        rng, input_emb, deterministic=True)
  with mock.patch.object(
      scatter_router, "router_weights", return_value=mock_router_logits):
    scatter_outputs, _ = scatter_moe_layer.init_with_output(
        rng, input_emb, deterministic=True)

  expected_outputs = jnp.array(
      [
          [
              [-8.16194050e-04, -3.92473085e-05],
              [-8.87976727e-04, 6.41788647e-05],
              [1.51725704e-04, 5.44631148e-05],
              [0.00000000e+00, 0.00000000e+00],
          ],
          [
              [-1.63517136e-03, 7.32473345e-05],
              [6.99331111e-04, -4.98824847e-05],
              [-7.68527039e-04, -1.00117592e-04],
              [3.73630854e-03, 1.74387533e-04],
          ],
          [
              [1.09393802e-03, 5.09395104e-05],
              [-4.27273808e-05, 1.12514383e-04],
              [3.19827022e-03, 1.41921133e-04],
              [2.31421960e-04, -2.57078882e-05],
          ],
          [
              [0.00000000e+00, 0.00000000e+00],
              [1.65408337e-03, 1.62946199e-05],
              [2.29193736e-03, 1.07774074e-04],
              [-9.18464328e-04, -4.17242954e-05],
          ],
      ],
      dtype=jnp.float32)
  np.testing.assert_allclose(
      masked_outputs, expected_outputs, rtol=1e-6, atol=1e-6)
  np.testing.assert_allclose(
      scatter_outputs, expected_outputs, rtol=1e-6, atol=1e-6)
def test_encoder_block_switch(self):
  batch_size = 2
  max_seq_length = 14
  num_tokens = batch_size * max_seq_length
  hidden_dim = 8
  rng = jax.random.PRNGKey(0)
  init_rng, dropout_key, jitter_key = jax.random.split(rng, num=3)

  expert = layers.FeedForwardLayer(d_ff=4, dropout_rate=0.0, name="mlp")
  router = routing.TokensChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      jitter_noise=0.01,
      num_selected_experts=1,
      batch_prioritized_routing=True,
      dtype=jnp.float32)
  moe_layer = layers.MoeLayer(
      num_experts=2,
      router=router,
      max_group_size=num_tokens,
      train_capacity_factor=1.0,
      eval_capacity_factor=1.0,
      expert=expert)
  mixing_layer = layers.LinearTransform()
  encoder_block = layers.EncoderBlock(
      feed_forward_sublayer=moe_layer,
      mixing_sublayer=mixing_layer,
      attention_sublayer=None)

  input_emb = jax.random.uniform(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=10)
  input_ids = jax.random.randint(
      init_rng, (batch_size, max_seq_length, hidden_dim), minval=0, maxval=20)
  params = init_layer_variables(rng, encoder_block, {
      "input_emb": input_emb,
      "input_ids": input_ids
  })["params"]

  expected_keys = {
      "mixing_sublayer", "mixing_layer_norm", "output_layer_norm",
      "feed_forward_sublayer"
  }
  self.assertEqual(params.keys(), expected_keys)

  outputs, state = encoder_block.apply(
      {"params": params},
      rngs={"dropout": dropout_key, "jitter": jitter_key},
      mutable=["intermediates"],
      input_emb=input_emb,
      input_ids=input_ids)
  self.assertEqual(outputs.shape, (batch_size, max_seq_length, hidden_dim))

  self.assertIn("intermediates", state)
  jax.tree_util.tree_map(
      functools.partial(np.testing.assert_allclose, rtol=1e-5),
      state["intermediates"],
      FrozenDict({
          "feed_forward_sublayer": {
              "diversity_metrics":
                  layers.DiversityMetrics(
                      auxiliary_loss=0.9997709,
                      router_z_loss=0.48709542,
                      fraction_tokens_left_behind=0.03571427,
                      expert_usage=0.96428573,
                      router_confidence=0.51779515)
          }
      }))
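
# A back-of-the-envelope check (ours, assuming the standard capacity rule) on
# the expected metrics above: all 28 tokens form a single group, and with
# 2 experts at capacity factor 1.0 each expert can hold 1.0 * 28 / 2 = 14
# tokens. Exactly one token overflows its top expert, so:
#   fraction_tokens_left_behind = 1 / 28 ~= 0.03571427
#   expert_usage = 27 / 28 ~= 0.96428573  # 27 of the 28 expert slots filled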
def test_scatter_and_mask_dispatch_equal(self):
  num_groups = 2
  tokens_per_group = 4
  hidden_dim = 3
  num_experts = 3
  num_selected_experts = 1
  expert_capacity = 2
  rng = jax.random.PRNGKey(0)

  router_weights = routing.RouterWeights(name="router_weights")
  token_inputs = jax.random.uniform(
      rng, (num_groups, tokens_per_group, hidden_dim), minval=0, maxval=1)

  router_mask, _ = routing.TokensChooseMaskedRouter(
      router_weights,
      num_selected_experts=num_selected_experts,
      jitter_noise=0.,
      batch_prioritized_routing=True,
      dtype=jnp.float32).init_with_output(jax.random.PRNGKey(0), token_inputs,
                                          num_experts, expert_capacity)

  # Manipulate masked router dispatch and combine arrays to match format of
  # scatter router output.
  # Ignore capacity. Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS]
  masked_router_says_dispatched = jnp.max(router_mask.dispatch_mask, axis=-1)
  # Ignore particular expert and capacity for combine array.
  # Shape: [NUM_GROUPS, TOKENS_PER_GROUP]
  masked_router_combine_array = jnp.max(
      router_mask.combine_array, axis=(-1, -2))

  router_indices, _ = routing.TokensChooseScatterRouter(
      router_weights,
      num_selected_experts=num_selected_experts,
      jitter_noise=0.,
      batch_prioritized_routing=True,
      dtype=jnp.float32).init_with_output(jax.random.PRNGKey(0), token_inputs,
                                          num_experts, expert_capacity)

  # Manipulate scatter router dispatch and combine indices to match format of
  # masked router output.
  # Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_SELECTED_EXPERTS]
  successfully_routed = (
      router_indices.dispatch_indices[Ellipsis, 1] < expert_capacity)
  # Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS]
  scatter_router_says_dispatched = successfully_routed * jax.nn.one_hot(
      router_indices.dispatch_indices[Ellipsis, 0].squeeze(axis=-1),
      num_experts)

  # Remove trivial selected expert axis.
  # Shape: [NUM_GROUPS, TOKENS_PER_GROUP]
  scatter_router_combine_array = router_indices.combine_weights.squeeze(
      axis=-1)

  np.testing.assert_allclose(masked_router_says_dispatched,
                             scatter_router_says_dispatched)
  np.testing.assert_allclose(masked_router_combine_array,
                             scatter_router_combine_array)
  np.testing.assert_allclose(router_mask.auxiliary_loss,
                             router_indices.auxiliary_loss)
  np.testing.assert_allclose(router_mask.router_z_loss,
                             router_indices.router_z_loss)
def test_tokens_choose_multiple_experts_mask_router(self):
  num_groups = 2
  tokens_per_group = 4
  hidden_dim = 3
  num_experts = 3
  num_selected_experts = 2
  expert_capacity = 1

  rng = jax.random.PRNGKey(0)
  token_inputs = jax.random.uniform(
      rng, (num_groups, tokens_per_group, hidden_dim), minval=0, maxval=1)

  router_mask, _ = routing.TokensChooseMaskedRouter(
      router_weights=routing.RouterWeights(name="router_weights"),
      num_selected_experts=num_selected_experts,
      jitter_noise=0.01,
      batch_prioritized_routing=True,
      dtype=jnp.float32).init_with_output(
          {
              "params": jax.random.PRNGKey(0),
              "jitter": jax.random.PRNGKey(0)
          }, token_inputs, num_experts, expert_capacity)

  expected_mask = jnp.array(
      [
          [
              [[False], [False], [False]],
              [[False], [False], [False]],
              [[False], [False], [False]],
              [[True], [True], [False]],
          ],
          [
              [[False], [False], [False]],
              [[False], [False], [False]],
              [[False], [False], [False]],
              [[True], [True], [False]],
          ],
      ],
      dtype=jnp.bool_)
  np.testing.assert_allclose(router_mask.dispatch_mask, expected_mask)

  expected_weights = jnp.array(
      [
          [
              [[0.], [0.], [0.]],
              [[0.], [0.], [0.]],
              [[0.], [0.], [0.]],
              [[0.32617188], [0.3515625], [0.]],
          ],
          [
              [[0.], [0.], [0.]],
              [[0.], [0.], [0.]],
              [[0.], [0.], [0.]],
              [[0.32226562], [0.36328125], [0.]],
          ],
      ],
      dtype=jnp.float32)
  np.testing.assert_allclose(router_mask.combine_array, expected_weights)

  self.assertEqual(router_mask.auxiliary_loss, 2.025879)
  self.assertEqual(router_mask.router_z_loss, 1.2324219)