Exemplo n.º 1
0
 def _init_routers(self, layer):
     """Builds the single router used by the MoE sublayer identified by `layer`.

     The router class is selected from `self.config.dispatch_algorithm`:
       * MASK_TOKENS_CHOOSE    -> routing.TokensChooseMaskedRouter
       * SCATTER_TOKENS_CHOOSE -> routing.TokensChooseScatterRouter
       * MASK_EXPERTS_CHOOSE   -> routing.ExpertsChooseMaskedRouter

     Every router receives the following:
       * a fresh `routing.RouterWeights` named `router_weights_{layer}`;
       * the configured `jitter_noise`;
       * `dtype=layers.truncated_dtype()`.

     The two tokens-choose routers additionally receive `num_selected_experts`
     and `batch_prioritized_routing` from the config. The experts-choose router
     is not given those arguments.

     Args:
       layer: Layer identifier; only used to name the router weights submodule.

     Returns:
       One router instance.

     Raises:
       ValueError: If `self.config.dispatch_algorithm` is none of the above.
     """
     algorithm = self.config.dispatch_algorithm
     if algorithm == DispatchAlgorithm.MASK_TOKENS_CHOOSE:
         router_cls, tokens_choose = routing.TokensChooseMaskedRouter, True
     elif algorithm == DispatchAlgorithm.SCATTER_TOKENS_CHOOSE:
         router_cls, tokens_choose = routing.TokensChooseScatterRouter, True
     elif algorithm == DispatchAlgorithm.MASK_EXPERTS_CHOOSE:
         router_cls, tokens_choose = routing.ExpertsChooseMaskedRouter, False
     else:
         raise ValueError(f"Unrecognized dispatch_algorithm: {algorithm}")

     # Submodules are only constructed once the algorithm is known to be valid,
     # exactly as in a per-branch construction.
     router_kwargs = {
         "router_weights":
             routing.RouterWeights(name=f"router_weights_{layer}"),
         "jitter_noise": self.config.jitter_noise,
     }
     if tokens_choose:
         # Only the tokens-choose routers take these two options.
         router_kwargs["num_selected_experts"] = self.config.num_selected_experts
         router_kwargs["batch_prioritized_routing"] = (
             self.config.batch_prioritized_routing)
     return router_cls(dtype=layers.truncated_dtype(), **router_kwargs)
Exemplo n.º 2
0
    def test_moe_layer_runs(self, dispatch):
        """Smoke-tests MoeLayer initialization and application.

        `dispatch == "mask"` uses the tokens-choose masked router; any other
        value uses the tokens-choose scatter router.
        """
        batch_size, max_seq_length, hidden_dim = 3, 4, 2
        num_tokens = batch_size * max_seq_length
        num_experts = 4
        rng = jax.random.PRNGKey(0)

        router_cls = (routing.TokensChooseMaskedRouter if dispatch == "mask"
                      else routing.TokensChooseScatterRouter)
        router = router_cls(
            router_weights=routing.RouterWeights(name="router_weights"),
            jitter_noise=0.,
            num_selected_experts=2,
            batch_prioritized_routing=True,
            dtype=jnp.float32)

        mlp_expert = layers.FeedForwardLayer(d_ff=2,
                                             dropout_rate=0.1,
                                             name="mlp")
        moe_layer = layers.MoeLayer(
            num_experts=num_experts,
            max_group_size=num_tokens,
            router=router,
            train_capacity_factor=1.5,
            eval_capacity_factor=1.5,
            expert=mlp_expert,
            axis_name="batch")

        # Initialize with a single all-ones example.
        dummy_emb = jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        params = init_layer_variables(rng, moe_layer,
                                      {"input_emb": dummy_emb})["params"]
        self.assertEqual(params.keys(), {"router", "expert"})

        dropout_rng, jitter_rng, init_rng = jax.random.split(rng, num=3)
        output_shape = (batch_size, max_seq_length, hidden_dim)
        input_emb = jax.random.uniform(init_rng,
                                       output_shape,
                                       minval=-10,
                                       maxval=10)
        apply_rngs = {"dropout": dropout_rng, "jitter": jitter_rng}
        actual_outputs, state = moe_layer.apply({"params": params},
                                                rngs=apply_rngs,
                                                mutable=["intermediates"],
                                                input_emb=input_emb)

        self.assertEqual(actual_outputs.shape, output_shape)
        self.assertIn("diversity_metrics", state["intermediates"])
Exemplo n.º 3
0
    def test_experts_choose_mask_router(self):
        """Checks ExpertsChooseMaskedRouter dispatch mask, combine array, losses."""
        num_groups = 2
        tokens_per_group = 4
        hidden_dim = 3
        num_experts = 2
        expert_capacity = 2
        rng = jax.random.PRNGKey(0)

        token_inputs = jax.random.uniform(
            rng, (num_groups, tokens_per_group, hidden_dim),
            minval=0,
            maxval=1)
        router_mask, _ = routing.ExpertsChooseMaskedRouter(
            router_weights=routing.RouterWeights(name="router_weights"),
            jitter_noise=0.,
            dtype=jnp.float32).init_with_output(jax.random.PRNGKey(0),
                                                token_inputs, num_experts,
                                                expert_capacity)

        expected_mask = jnp.array([
            [
                [[0, 1], [1, 0]],
                [[0, 0], [0, 0]],
                [[0, 0], [0, 1]],
                [[1, 0], [0, 0]],
            ],
            [
                [[0, 0], [0, 1]],
                [[0, 0], [1, 0]],
                [[0, 1], [0, 0]],
                [[1, 0], [0, 0]],
            ],
        ],
                                  dtype=jnp.int32)

        np.testing.assert_allclose(router_mask.dispatch_mask, expected_mask)

        expected_weights = jnp.array([
            [
                [[0., 0.50390625], [0.49804688, 0.]],
                [[0., 0.], [0., 0.]],
                [[0., 0.], [0., 0.49804688]],
                [[0.5078125, 0.], [0., 0.]],
            ],
            [
                [[0., 0.], [0., 0.49414062]],
                [[0., 0.], [0.49609375, 0.]],
                [[0., 0.5078125], [0., 0.]],
                [[0.51171875, 0.], [0., 0.]],
            ],
        ],
                                     dtype=jnp.float32)
        np.testing.assert_allclose(router_mask.combine_array, expected_weights)

        # Auxiliary loss is always 0. for experts choose tokens routing, so an
        # exact comparison is appropriate here.
        self.assertEqual(router_mask.auxiliary_loss, 0.)
        # The z-loss is a reduced float; compare with a tolerance rather than
        # exact equality, which can break on last-bit differences.
        np.testing.assert_allclose(router_mask.router_z_loss,
                                   0.48657227,
                                   rtol=1e-6)
Exemplo n.º 4
0
    def test_tokens_choose_one_expert_mask_router_no_bpr(self):
        """Checks TokensChooseMaskedRouter with one expert per token, no BPR.

        Capacity is deliberately smaller than the number of tokens, so some
        tokens are not dispatched.
        """
        num_groups = 2
        tokens_per_group = 3
        hidden_dim = 4
        num_experts = 2
        num_selected_experts = 1  # Switch routing case
        expert_capacity = 1  # Total capacity = 2*2*1 = 4 < num_tokens
        rng = jax.random.PRNGKey(0)

        token_inputs = jax.random.uniform(
            rng, (num_groups, tokens_per_group, hidden_dim),
            minval=0,
            maxval=1)
        router_mask, _ = routing.TokensChooseMaskedRouter(
            router_weights=routing.RouterWeights(name="router_weights"),
            num_selected_experts=num_selected_experts,
            jitter_noise=0.,
            batch_prioritized_routing=False,
            dtype=jnp.float32).init_with_output(jax.random.PRNGKey(0),
                                                token_inputs, num_experts,
                                                expert_capacity)

        expected_mask = jnp.array([
            [
                [[False], [True]],
                [[False], [False]],
                [[False], [False]],
            ],
            [
                [[False], [True]],
                [[True], [False]],
                [[False], [False]],
            ],
        ],
                                  dtype=jnp.bool_)

        np.testing.assert_allclose(router_mask.dispatch_mask, expected_mask)

        expected_weights = jnp.array([
            [
                [[0.], [0.5078125]],
                [[0.], [0.]],
                [[0.], [0.]],
            ],
            [
                [[0.], [0.50390625]],
                [[0.5], [0.]],
                [[0.], [0.]],
            ],
        ],
                                     dtype=jnp.float32)
        np.testing.assert_allclose(router_mask.combine_array, expected_weights)

        # Losses are reduced floats; compare with a tolerance rather than exact
        # equality, which can break on last-bit differences.
        np.testing.assert_allclose(router_mask.auxiliary_loss,
                                   1.0065105,
                                   rtol=1e-6)
        np.testing.assert_allclose(router_mask.router_z_loss,
                                   0.4716797,
                                   rtol=1e-6)
Exemplo n.º 5
0
    def test_tokens_choose_one_expert_scatter_router_no_bpr(self):
        """Checks TokensChooseScatterRouter with one expert per token, no BPR.

        Jitter noise is enabled, so a "jitter" RNG is supplied at init.
        """
        num_groups = 2
        tokens_per_group = 4
        hidden_dim = 4
        expert_capacity = 2
        num_experts = 4
        num_selected_experts = 1  # Switch routing case
        rng = jax.random.PRNGKey(0)

        token_inputs = jax.random.uniform(
            rng, (num_groups, tokens_per_group, hidden_dim),
            minval=0,
            maxval=1)
        router_indices, _ = routing.TokensChooseScatterRouter(
            router_weights=routing.RouterWeights(name="router_weights"),
            num_selected_experts=num_selected_experts,
            jitter_noise=0.01,
            batch_prioritized_routing=False,
            dtype=jnp.float32).init_with_output(
                {
                    "params": jax.random.PRNGKey(0),
                    "jitter": jax.random.PRNGKey(0)
                }, token_inputs, num_experts, expert_capacity)

        expected_indices = jnp.array([
            [
                [[1, 0]],
                [[0, 0]],
                [[1, 1]],
                [[1, 2]],
            ],
            [
                [[0, 0]],
                [[1, 0]],
                [[1, 1]],
                [[1, 2]],
            ],
        ],
                                     dtype=jnp.int32)

        np.testing.assert_allclose(router_indices.dispatch_indices,
                                   expected_indices)

        expected_weights = jnp.array([
            [[0.2578125], [0.25390625], [0.25585938], [0.]],
            [[0.2578125], [0.25390625], [0.25390625], [0.]],
        ],
                                     dtype=jnp.float32)
        np.testing.assert_allclose(router_indices.combine_weights,
                                   expected_weights)

        # Losses are reduced floats; compare with a tolerance rather than exact
        # equality, which can break on last-bit differences.
        np.testing.assert_allclose(router_indices.auxiliary_loss,
                                   1.0168457,
                                   rtol=1e-6)
        np.testing.assert_allclose(router_indices.router_z_loss,
                                   1.9111328,
                                   rtol=1e-6)
Exemplo n.º 6
0
    def test_num_groups(self, max_group_size, num_tokens, num_experts,
                        expected_num_groups):
        """Checks that MoeLayer._num_groups returns `expected_num_groups`."""
        ff_expert = layers.FeedForwardLayer(d_ff=2)
        moe_layer = layers.MoeLayer(
            num_experts=num_experts,
            router=routing.ExpertsChooseMaskedRouter(
                router_weights=routing.RouterWeights(name="router_weights"),
                jitter_noise=0.,
                dtype=jnp.float32),
            max_group_size=max_group_size,
            train_capacity_factor=1.,
            eval_capacity_factor=1.,
            expert=ff_expert)

        self.assertEqual(moe_layer._num_groups(num_tokens, max_group_size),
                         expected_num_groups)
Exemplo n.º 7
0
    def test_scatter_mask_dispatch_equal(self):
        """Masked and scatter MoE layers must produce the same outputs.

        Both routers' `router_weights` are patched to return the same logits,
        so any output difference comes from the dispatch implementation.
        """
        batch_size, max_seq_length = 4, 4
        hidden_dim = 2
        num_experts = 2
        tokens_per_group = 8
        num_groups = batch_size * max_seq_length // tokens_per_group

        rng = jax.random.PRNGKey(0)

        expert = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
        moe_layer_factory = functools.partial(
            layers.MoeLayer,
            num_experts=num_experts,
            max_group_size=tokens_per_group,
            train_capacity_factor=1.,
            eval_capacity_factor=1.,
            expert=expert,
            split_params=False)  # Ensures all experts start with same params

        shared_router_weights = routing.RouterWeights(name="router_weights")
        router_options = dict(router_weights=shared_router_weights,
                              jitter_noise=0.,
                              num_selected_experts=2,
                              batch_prioritized_routing=True,
                              dtype=jnp.float32)
        masked_router = routing.TokensChooseMaskedRouter(**router_options)
        masked_moe_layer = moe_layer_factory(router=masked_router)
        scatter_router = routing.TokensChooseScatterRouter(**router_options)
        scatter_moe_layer = moe_layer_factory(router=scatter_router)

        input_emb = jax.random.uniform(
            rng, (batch_size, max_seq_length, hidden_dim),
            minval=-10,
            maxval=10)

        # Mock the router weights to ensure both layers compute with the same
        # logits.
        mock_router_logits = jax.random.uniform(
            rng, (num_groups, tokens_per_group, num_experts),
            minval=-1,
            maxval=1)
        all_outputs = []
        for router, moe_layer in ((masked_router, masked_moe_layer),
                                  (scatter_router, scatter_moe_layer)):
            with mock.patch.object(router,
                                   "router_weights",
                                   return_value=mock_router_logits):
                layer_outputs, _ = moe_layer.init_with_output(
                    rng, input_emb, deterministic=True)
            all_outputs.append(layer_outputs)

        expected_outputs = jnp.array([
            [
                [-8.16194050e-04, -3.92473085e-05],
                [-8.87976727e-04, 6.41788647e-05],
                [1.51725704e-04, 5.44631148e-05],
                [0.00000000e+00, 0.00000000e+00],
            ],
            [
                [-1.63517136e-03, 7.32473345e-05],
                [6.99331111e-04, -4.98824847e-05],
                [-7.68527039e-04, -1.00117592e-04],
                [3.73630854e-03, 1.74387533e-04],
            ],
            [
                [1.09393802e-03, 5.09395104e-05],
                [-4.27273808e-05, 1.12514383e-04],
                [3.19827022e-03, 1.41921133e-04],
                [2.31421960e-04, -2.57078882e-05],
            ],
            [
                [0.00000000e+00, 0.00000000e+00],
                [1.65408337e-03, 1.62946199e-05],
                [2.29193736e-03, 1.07774074e-04],
                [-9.18464328e-04, -4.17242954e-05],
            ],
        ],
                                     dtype=jnp.float32)

        # Masked outputs are checked first, then scatter outputs.
        for layer_outputs in all_outputs:
            np.testing.assert_allclose(layer_outputs,
                                       expected_outputs,
                                       rtol=1e-6,
                                       atol=1e-6)
Exemplo n.º 8
0
    def test_encoder_block_switch(self):
        """Runs an EncoderBlock whose feed-forward sublayer is a Switch MoE layer.

        The router is a tokens-choose masked router with one selected expert
        over two experts. The test checks the parameter names, the output shape
        and the sown diversity metrics.
        """
        batch_size, max_seq_length, hidden_dim = 2, 14, 8
        num_tokens = batch_size * max_seq_length
        rng = jax.random.PRNGKey(0)
        init_rng, dropout_key, jitter_key = jax.random.split(rng, num=3)

        ff_expert = layers.FeedForwardLayer(d_ff=4,
                                            dropout_rate=0.0,
                                            name="mlp")
        switch_router = routing.TokensChooseMaskedRouter(
            router_weights=routing.RouterWeights(name="router_weights"),
            jitter_noise=0.01,
            num_selected_experts=1,
            batch_prioritized_routing=True,
            dtype=jnp.float32)
        moe_layer = layers.MoeLayer(num_experts=2,
                                    router=switch_router,
                                    max_group_size=num_tokens,
                                    train_capacity_factor=1.0,
                                    eval_capacity_factor=1.0,
                                    expert=ff_expert)

        encoder_block = layers.EncoderBlock(
            feed_forward_sublayer=moe_layer,
            mixing_sublayer=layers.LinearTransform(),
            attention_sublayer=None)

        inputs_shape = (batch_size, max_seq_length, hidden_dim)
        input_emb = jax.random.uniform(init_rng, inputs_shape, minval=0,
                                       maxval=10)
        # NOTE(review): input_ids carries a trailing hidden_dim axis and reuses
        # init_rng; kept as-is since the expected metrics below presumably
        # depend on these exact inputs.
        input_ids = jax.random.randint(init_rng, inputs_shape, minval=0,
                                       maxval=20)
        init_batch = {"input_emb": input_emb, "input_ids": input_ids}
        params = init_layer_variables(rng, encoder_block, init_batch)["params"]

        self.assertEqual(
            params.keys(), {
                "mixing_sublayer", "mixing_layer_norm", "output_layer_norm",
                "feed_forward_sublayer"
            })

        outputs, state = encoder_block.apply(
            {"params": params},
            rngs={"dropout": dropout_key, "jitter": jitter_key},
            mutable=["intermediates"],
            input_emb=input_emb,
            input_ids=input_ids)
        self.assertEqual(outputs.shape, inputs_shape)

        self.assertIn("intermediates", state)
        expected_metrics = layers.DiversityMetrics(
            auxiliary_loss=0.9997709,
            router_z_loss=0.48709542,
            fraction_tokens_left_behind=0.03571427,
            expert_usage=0.96428573,
            router_confidence=0.51779515)
        expected_intermediates = FrozenDict({
            "feed_forward_sublayer": {
                "diversity_metrics": expected_metrics
            }
        })
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, rtol=1e-5),
            state["intermediates"], expected_intermediates)
Exemplo n.º 9
0
    def test_scatter_and_mask_dispatch_equal(self):
        """Masked and scatter tokens-choose routers must agree.

        Both routers share the same RouterWeights and inputs. Their outputs are
        reduced to a common [NUM_GROUPS, TOKENS_PER_GROUP, ...] format and
        compared, together with both losses.
        """
        num_groups, tokens_per_group = 2, 4
        hidden_dim = 3
        num_experts = 3
        num_selected_experts = 1
        expert_capacity = 2
        rng = jax.random.PRNGKey(0)

        router_weights = routing.RouterWeights(name="router_weights")

        token_inputs = jax.random.uniform(
            rng, (num_groups, tokens_per_group, hidden_dim),
            minval=0,
            maxval=1)

        router_options = dict(num_selected_experts=num_selected_experts,
                              jitter_noise=0.,
                              batch_prioritized_routing=True,
                              dtype=jnp.float32)
        init_args = (jax.random.PRNGKey(0), token_inputs, num_experts,
                     expert_capacity)

        router_mask, _ = routing.TokensChooseMaskedRouter(
            router_weights, **router_options).init_with_output(*init_args)
        # Reshape the masked router output into the scatter router's format.
        # Dropping the capacity axis gives shape
        # [NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS].
        masked_dispatched = jnp.max(router_mask.dispatch_mask, axis=-1)
        # Dropping the expert and capacity axes gives shape
        # [NUM_GROUPS, TOKENS_PER_GROUP].
        masked_combine = jnp.max(router_mask.combine_array, axis=(-1, -2))

        router_indices, _ = routing.TokensChooseScatterRouter(
            router_weights, **router_options).init_with_output(*init_args)
        # Reshape the scatter router output into the masked router's format.
        dispatch_indices = router_indices.dispatch_indices
        # A token counts as routed when its capacity slot (index 1) is in range.
        # Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_SELECTED_EXPERTS].
        successfully_routed = dispatch_indices[..., 1] < expert_capacity
        # Shape: [NUM_GROUPS, TOKENS_PER_GROUP, NUM_EXPERTS].
        expert_one_hot = jax.nn.one_hot(
            dispatch_indices[..., 0].squeeze(axis=-1), num_experts)
        scatter_dispatched = successfully_routed * expert_one_hot
        # Drop the trivial selected-expert axis.
        # Shape: [NUM_GROUPS, TOKENS_PER_GROUP].
        scatter_combine = router_indices.combine_weights.squeeze(axis=-1)

        for masked_value, scatter_value in (
            (masked_dispatched, scatter_dispatched),
            (masked_combine, scatter_combine),
            (router_mask.auxiliary_loss, router_indices.auxiliary_loss),
            (router_mask.router_z_loss, router_indices.router_z_loss),
        ):
            np.testing.assert_allclose(masked_value, scatter_value)
Exemplo n.º 10
0
    def test_tokens_choose_multiple_experts_mask_router(self):
        """Checks TokensChooseMaskedRouter when each token picks two experts.

        Batch prioritized routing and jitter noise are enabled, so a "jitter"
        RNG is supplied at init.
        """
        num_groups = 2
        tokens_per_group = 4
        hidden_dim = 3
        num_experts = 3
        num_selected_experts = 2
        expert_capacity = 1
        rng = jax.random.PRNGKey(0)

        token_inputs = jax.random.uniform(
            rng, (num_groups, tokens_per_group, hidden_dim),
            minval=0,
            maxval=1)
        router_mask, _ = routing.TokensChooseMaskedRouter(
            router_weights=routing.RouterWeights(name="router_weights"),
            num_selected_experts=num_selected_experts,
            jitter_noise=0.01,
            batch_prioritized_routing=True,
            dtype=jnp.float32).init_with_output(
                {
                    "params": jax.random.PRNGKey(0),
                    "jitter": jax.random.PRNGKey(0)
                }, token_inputs, num_experts, expert_capacity)

        expected_mask = jnp.array([
            [
                [[False], [False], [False]],
                [[False], [False], [False]],
                [[False], [False], [False]],
                [[True], [True], [False]],
            ],
            [
                [[False], [False], [False]],
                [[False], [False], [False]],
                [[False], [False], [False]],
                [[True], [True], [False]],
            ],
        ],
                                  dtype=jnp.bool_)

        np.testing.assert_allclose(router_mask.dispatch_mask, expected_mask)

        expected_weights = jnp.array([
            [
                [[0.], [0.], [0.]],
                [[0.], [0.], [0.]],
                [[0.], [0.], [0.]],
                [[0.32617188], [0.3515625], [0.]],
            ],
            [
                [[0.], [0.], [0.]],
                [[0.], [0.], [0.]],
                [[0.], [0.], [0.]],
                [[0.32226562], [0.36328125], [0.]],
            ],
        ],
                                     dtype=jnp.float32)
        np.testing.assert_allclose(router_mask.combine_array, expected_weights)

        # Losses are reduced floats; compare with a tolerance rather than exact
        # equality, which can break on last-bit differences.
        np.testing.assert_allclose(router_mask.auxiliary_loss,
                                   2.025879,
                                   rtol=1e-6)
        np.testing.assert_allclose(router_mask.router_z_loss,
                                   1.2324219,
                                   rtol=1e-6)