def _init_feed_forward_sublayer(self, layer):
        """Builds the feed-forward sublayer used by encoder layer `layer`.

        Args:
          layer: Encoder layer index. It decides (via `self._is_moe_layer`)
            whether a mixture-of-experts or a dense sublayer is built, and it
            is embedded in the sublayer's name.

        Returns:
          A `layers.MoeLayer` wrapping a `layers.FeedForwardLayer` expert when
          `self._is_moe_layer(layer)` is truthy; otherwise a plain
          `layers.FeedForwardLayer` sized by `config.d_ff`.
        """
        cfg = self.config

        # Dense path: a single feed-forward block with the model-wide settings.
        if not self._is_moe_layer(layer):
            return layers.FeedForwardLayer(
                d_ff=cfg.d_ff,
                dropout_rate=cfg.dropout_rate,
                dtype=cfg.dtype,
                name=f"feed_forward_{layer}")

        # MoE path. The expert has its own hidden size and dropout rate, and
        # takes its dtype from `layers.truncated_dtype()` instead of cfg.dtype.
        expert_ffn = layers.FeedForwardLayer(
            d_ff=cfg.expert_d_ff,
            dropout_rate=cfg.expert_dropout_rate,
            dtype=layers.truncated_dtype(),
            name=f"expert_{layer}")
        return layers.MoeLayer(
            num_experts=cfg.num_experts,
            router=self._init_routers(layer),
            max_group_size=cfg.max_group_size,
            train_capacity_factor=cfg.train_capacity_factor,
            eval_capacity_factor=cfg.eval_capacity_factor,
            expert=expert_ffn,
            min_expert_capacity=cfg.min_expert_capacity,
            dropout_rate=cfg.expert_dropout_rate,
            dtype=cfg.dtype,
            name=f"moe_{layer}")
Exemple #2
0
    def test_feed_forward_layer(self):
        """FeedForwardLayer creates its two projections and preserves shape."""
        batch_size = 3
        max_seq_length = 16
        hidden_dim = 12
        rng = jax.random.PRNGKey(0)

        feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.1)
        init_batch = {
            "input_emb": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        params = init_layer_variables(rng, feed_forward_layer,
                                      init_batch)["params"]

        expected_keys = {"intermediate", "output"}
        self.assertEqual(params.keys(), expected_keys)

        rng, init_rng = jax.random.split(rng)
        # Real-valued embeddings, matching the float32 init batch above.
        # randint would produce int32 "embeddings", which do not match the
        # dtype the layer was initialized with.
        input_emb = jax.random.uniform(
            init_rng, (batch_size, max_seq_length, hidden_dim),
            minval=0,
            maxval=10)
        outputs = feed_forward_layer.apply({"params": params},
                                           rngs={"dropout": rng},
                                           input_emb=input_emb)

        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
Exemple #3
0
    def test_moe_layer_runs(self, dispatch):
        """Smoke test: an MoE layer runs end to end with either router type.

        Args:
          dispatch: "mask" selects `TokensChooseMaskedRouter`; any other value
            selects `TokensChooseScatterRouter`.
        """
        batch_size, max_seq_length, hidden_dim = 3, 4, 2
        num_tokens = batch_size * max_seq_length
        num_experts = 4
        rng = jax.random.PRNGKey(0)

        # Both router flavours take identical arguments; only the class differs.
        if dispatch == "mask":
            router_cls = routing.TokensChooseMaskedRouter
        else:
            router_cls = routing.TokensChooseScatterRouter
        router = router_cls(
            router_weights=routing.RouterWeights(name="router_weights"),
            jitter_noise=0.,
            num_selected_experts=2,
            batch_prioritized_routing=True,
            dtype=jnp.float32)

        mlp = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
        moe_layer = layers.MoeLayer(
            num_experts=num_experts,
            max_group_size=num_tokens,
            router=router,
            train_capacity_factor=1.5,
            eval_capacity_factor=1.5,
            expert=mlp,
            axis_name="batch")

        dummy_batch = {
            "input_emb": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        params = init_layer_variables(rng, moe_layer, dummy_batch)["params"]
        self.assertEqual(params.keys(), {"router", "expert"})

        dropout_key, jitter_key, emb_key = jax.random.split(rng, num=3)
        embeddings = jax.random.uniform(
            emb_key, (batch_size, max_seq_length, hidden_dim),
            minval=-10,
            maxval=10)
        outputs, state = moe_layer.apply(
            {"params": params},
            rngs={"dropout": dropout_key, "jitter": jitter_key},
            mutable=["intermediates"],
            input_emb=embeddings)

        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
        # The layer records routing diagnostics in the "intermediates" collection.
        self.assertIn("diversity_metrics", state["intermediates"])
Exemple #4
0
    def test_num_groups(self, max_group_size, num_tokens, num_experts,
                        expected_num_groups):
        """`MoeLayer._num_groups` returns the expected number of token groups."""
        ffn_expert = layers.FeedForwardLayer(d_ff=2)
        moe_layer = layers.MoeLayer(
            num_experts=num_experts,
            router=routing.ExpertsChooseMaskedRouter(
                router_weights=routing.RouterWeights(name="router_weights"),
                jitter_noise=0.,
                dtype=jnp.float32),
            max_group_size=max_group_size,
            train_capacity_factor=1.,
            eval_capacity_factor=1.,
            expert=ffn_expert)

        actual = moe_layer._num_groups(num_tokens, max_group_size)
        self.assertEqual(actual, expected_num_groups)
Exemple #5
0
    def test_construct_encoder_block_correctly(self):
        """EncoderBlock requires exactly one of a mixing or attention sublayer."""
        max_seq_length = 14
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)

        feed_forward_layer = layers.FeedForwardLayer(d_ff=8)
        mixing_layer = layers.LinearTransform()
        attention_layer = layers.AttentionLayer(num_heads=1, d_model=2)

        init_batch = {
            "input_emb": jnp.ones((1, max_seq_length, hidden_dim),
                                  jnp.float32),
            "input_ids": jnp.ones((1, max_seq_length), jnp.int32),
        }

        def build_and_init(mixing, attention):
            # Construction and initialization both run here, so an error
            # raised by either step surfaces at the call site.
            block = layers.EncoderBlock(
                feed_forward_sublayer=feed_forward_layer,
                mixing_sublayer=mixing,
                attention_sublayer=attention)
            return init_layer_variables(rng, block, init_batch)

        # Exactly one sublayer given: must succeed.
        build_and_init(mixing_layer, None)

        # Both given, then neither given: each must raise.
        error_pattern = ("One, and only one, of {self.mixing_sublayer, "
                         "self.attention_sublayer} must be nonempty")
        bad_configs = [(mixing_layer, attention_layer), (None, None)]
        for mixing, attention in bad_configs:
            with self.assertRaisesRegex(ValueError, error_pattern):
                build_and_init(mixing, attention)
Exemple #6
0
    def test_encoder_block_feed_forward(self):
        """EncoderBlock with a dense feed-forward sublayer preserves shape."""
        batch_size = 2
        max_seq_length = 14
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)
        # Separate keys so embeddings and token ids are drawn independently.
        rng, emb_rng, ids_rng = jax.random.split(rng, num=3)

        feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.0)
        mixing_layer = layers.LinearTransform()
        encoder_block = layers.EncoderBlock(
            feed_forward_sublayer=feed_forward_layer,
            mixing_sublayer=mixing_layer,
            attention_sublayer=None)
        input_emb = jax.random.uniform(
            emb_rng, (batch_size, max_seq_length, hidden_dim),
            minval=0,
            maxval=10)
        # Token ids are 2-D (batch, seq), matching the init batch used when
        # constructing encoder blocks elsewhere in this test suite. The
        # original code incorrectly added a trailing hidden_dim axis.
        input_ids = jax.random.randint(
            ids_rng, (batch_size, max_seq_length),
            minval=0,
            maxval=20)
        params = init_layer_variables(rng, encoder_block, {
            "input_emb": input_emb,
            "input_ids": input_ids
        })["params"]

        expected_keys = {
            "mixing_sublayer", "mixing_layer_norm", "output_layer_norm",
            "feed_forward_sublayer"
        }
        self.assertEqual(params.keys(), expected_keys)

        outputs = encoder_block.apply({"params": params},
                                      rngs={"dropout": rng},
                                      input_emb=input_emb,
                                      input_ids=input_ids)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
Exemple #7
0
    def test_scatter_mask_dispatch_equal(self):
        """Masked and scatter dispatch yield identical MoE outputs.

        Two MoeLayers are built from the same factory (same expert instance,
        split_params=False) with two router types that share one
        RouterWeights module. Both routers' `router_weights` are patched to
        return the same logits, so the two layers differ only in how tokens
        are dispatched. Each output is compared against hard-coded golden
        values.

        NOTE(review): the golden values depend on the exact PRNG key usage
        below (the single `rng` is reused for inputs, mocked logits, and both
        inits), so any change to key handling requires regenerating them.
        """
        batch_size = 4
        max_seq_length = 4
        hidden_dim = 2
        num_experts = 2
        tokens_per_group = 8
        # 16 tokens / 8 per group -> 2 groups.
        num_groups = batch_size * max_seq_length // tokens_per_group

        rng = jax.random.PRNGKey(0)

        expert = layers.FeedForwardLayer(d_ff=2, dropout_rate=0.1, name="mlp")
        # Common MoeLayer settings; only the router differs between the two
        # layers under comparison.
        moe_layer_factory = functools.partial(
            layers.MoeLayer,
            num_experts=num_experts,
            max_group_size=tokens_per_group,
            train_capacity_factor=1.,
            eval_capacity_factor=1.,
            expert=expert,
            split_params=False)  # Ensures all experts start with same params

        # A single RouterWeights instance is shared by both routers.
        router_weights = routing.RouterWeights(name="router_weights")
        masked_router = routing.TokensChooseMaskedRouter(
            router_weights=router_weights,
            jitter_noise=0.,
            num_selected_experts=2,
            batch_prioritized_routing=True,
            dtype=jnp.float32)
        masked_moe_layer = moe_layer_factory(router=masked_router)
        scatter_router = routing.TokensChooseScatterRouter(
            router_weights=router_weights,
            jitter_noise=0.,
            num_selected_experts=2,
            batch_prioritized_routing=True,
            dtype=jnp.float32)
        scatter_moe_layer = moe_layer_factory(router=scatter_router)

        input_emb = jax.random.uniform(
            rng, (batch_size, max_seq_length, hidden_dim),
            minval=-10,
            maxval=10)

        # Mock the router weights to ensure both layers compute with the same
        # logits.
        # Logits are shaped (num_groups, tokens_per_group, num_experts).
        mock_router_logits = jax.random.uniform(
            rng, (num_groups, tokens_per_group, num_experts),
            minval=-1,
            maxval=1)
        # patch.object swaps `router_weights` for a MagicMock whose call
        # returns `mock_router_logits`; the patch is undone on exiting `with`.
        with mock.patch.object(masked_router,
                               "router_weights",
                               return_value=mock_router_logits):
            masked_outputs, _ = masked_moe_layer.init_with_output(
                rng, input_emb, deterministic=True)
        with mock.patch.object(scatter_router,
                               "router_weights",
                               return_value=mock_router_logits):
            scatter_outputs, _ = scatter_moe_layer.init_with_output(
                rng, input_emb, deterministic=True)

        # Golden outputs, shape (batch_size, max_seq_length, hidden_dim).
        # The all-zero rows are presumably tokens that were not processed by
        # any expert (e.g. dropped for capacity) -- TODO confirm.
        expected_outputs = jnp.array([
            [
                [-8.16194050e-04, -3.92473085e-05],
                [-8.87976727e-04, 6.41788647e-05],
                [1.51725704e-04, 5.44631148e-05],
                [0.00000000e+00, 0.00000000e+00],
            ],
            [
                [-1.63517136e-03, 7.32473345e-05],
                [6.99331111e-04, -4.98824847e-05],
                [-7.68527039e-04, -1.00117592e-04],
                [3.73630854e-03, 1.74387533e-04],
            ],
            [
                [1.09393802e-03, 5.09395104e-05],
                [-4.27273808e-05, 1.12514383e-04],
                [3.19827022e-03, 1.41921133e-04],
                [2.31421960e-04, -2.57078882e-05],
            ],
            [
                [0.00000000e+00, 0.00000000e+00],
                [1.65408337e-03, 1.62946199e-05],
                [2.29193736e-03, 1.07774074e-04],
                [-9.18464328e-04, -4.17242954e-05],
            ],
        ],
                                     dtype=jnp.float32)

        # Both dispatch strategies must match the same golden values, which
        # implies they match each other.
        np.testing.assert_allclose(masked_outputs,
                                   expected_outputs,
                                   rtol=1e-6,
                                   atol=1e-6)
        np.testing.assert_allclose(scatter_outputs,
                                   expected_outputs,
                                   rtol=1e-6,
                                   atol=1e-6)
Exemple #8
0
    def test_encoder_block_switch(self):
        """EncoderBlock with a top-1 MoE feed-forward sublayer.

        The MoE layer uses num_selected_experts=1 over 2 experts. The test
        checks the output shape and pins the router diversity metrics that
        the block records in the "intermediates" collection.

        NOTE(review): the golden metric values depend on the exact PRNG keys
        used below; changing key handling requires regenerating them.
        """
        batch_size = 2
        max_seq_length = 14
        num_tokens = batch_size * max_seq_length
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)
        init_rng, dropout_key, jitter_key = jax.random.split(rng, num=3)

        expert = layers.FeedForwardLayer(d_ff=4, dropout_rate=0.0, name="mlp")
        router = routing.TokensChooseMaskedRouter(
            router_weights=routing.RouterWeights(name="router_weights"),
            jitter_noise=0.01,
            num_selected_experts=1,
            batch_prioritized_routing=True,
            dtype=jnp.float32)
        # All tokens fit in a single group (max_group_size == num_tokens).
        moe_layer = layers.MoeLayer(num_experts=2,
                                    router=router,
                                    max_group_size=num_tokens,
                                    train_capacity_factor=1.0,
                                    eval_capacity_factor=1.0,
                                    expert=expert)

        mixing_layer = layers.LinearTransform()
        encoder_block = layers.EncoderBlock(feed_forward_sublayer=moe_layer,
                                            mixing_sublayer=mixing_layer,
                                            attention_sublayer=None)
        input_emb = jax.random.uniform(
            init_rng, (batch_size, max_seq_length, hidden_dim),
            minval=0,
            maxval=10)
        # NOTE(review): input_ids is 3-D here, while token ids elsewhere are
        # (batch, seq). It also reuses init_rng, the same key used for
        # input_emb. This is presumably harmless if the mixing path ignores
        # input_ids -- confirm before "fixing", since it may shift the golden
        # metrics.
        input_ids = jax.random.randint(
            init_rng, (batch_size, max_seq_length, hidden_dim),
            minval=0,
            maxval=20)
        # Initialization uses the root `rng`, not `init_rng`.
        params = init_layer_variables(rng, encoder_block, {
            "input_emb": input_emb,
            "input_ids": input_ids
        })["params"]

        expected_keys = {
            "mixing_sublayer", "mixing_layer_norm", "output_layer_norm",
            "feed_forward_sublayer"
        }
        self.assertEqual(params.keys(), expected_keys)

        # "intermediates" is made mutable so the recorded metrics are returned
        # in `state`.
        outputs, state = encoder_block.apply({"params": params},
                                             rngs={
                                                 "dropout": dropout_key,
                                                 "jitter": jitter_key
                                             },
                                             mutable=["intermediates"],
                                             input_emb=input_emb,
                                             input_ids=input_ids)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))

        self.assertIn("intermediates", state)
        # Leaf-wise comparison of the recorded metrics against golden values.
        # tree_map requires both trees to have the same structure.
        jax.tree_util.tree_map(
            functools.partial(np.testing.assert_allclose, rtol=1e-5),
            state["intermediates"],
            FrozenDict({
                "feed_forward_sublayer": {
                    "diversity_metrics":
                    layers.DiversityMetrics(
                        auxiliary_loss=0.9997709,
                        router_z_loss=0.48709542,
                        fraction_tokens_left_behind=0.03571427,
                        expert_usage=0.96428573,
                        router_confidence=0.51779515)
                }
            }))