Code example #1
    def test_hybrid_encoder(self, attention_layout, num_attention_layers,
                            expected_attention_layers):
        config = dummy_config(model_arch=ModelArchitecture.F_NET)
        with config.unlocked():
            config.num_layers = 4
            config.attention_layout = attention_layout
            config.num_attention_layers = num_attention_layers
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)
        params = init_model_params(rng, encoder, init_batch)

        expected_keys = {
            "embedder", "encoder_0", "encoder_1", "encoder_2", "encoder_3",
            "feed_forward_0", "feed_forward_1", "feed_forward_2",
            "feed_forward_3", "pooler"
        }
        for expected_attention_layer in expected_attention_layers:
            expected_keys.add(f"self_attention_{expected_attention_layer}")

        self.assertEqual(params.keys(), expected_keys)

        inputs = dummy_inputs(rng, config)
        hidden_states, pooled_output = encoder.apply({"params": params},
                                                     rngs={"dropout": rng},
                                                     **inputs)
        expected_hidden_states_shape = (config.train_batch_size,
                                        config.max_seq_length, config.d_model)
        self.assertEqual(hidden_states.shape, expected_hidden_states_shape)
        expected_pooled_output_shape = (config.train_batch_size,
                                        config.d_model)
        self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
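
A minimal sketch of how the expected self-attention block indices might be derived from a layout (the helper name and layout semantics are illustrative assumptions, not taken from the test above): with num_layers=4 and two attention layers placed at the top of the stack, the parameter tree gains "self_attention_2" and "self_attention_3".

def expected_attention_indices(num_layers, num_attention_layers, layout="top"):
    """Illustrative helper: indices of encoder blocks that would use self-attention."""
    if layout == "top":
        return list(range(num_layers - num_attention_layers, num_layers))
    if layout == "bottom":
        return list(range(num_attention_layers))
    raise ValueError(f"Unsupported layout in this sketch: {layout}")

assert expected_attention_indices(4, 2, "top") == [2, 3]
assert expected_attention_indices(4, 2, "bottom") == [0, 1]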
Code example #2
    def test_unparametrized_mixing_encoder(self, model_arch):
        config = dummy_config(model_arch=model_arch)
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)
        params = init_model_params(rng, encoder, init_batch)
        # Unparameterized mixing encoders do not have any parameters in their mixing
        # layers, so their mixing layer names do not show up in params.
        expected_keys = {
            "embedder", "encoder_0", "encoder_1", "feed_forward_0",
            "feed_forward_1", "pooler"
        }
        self.assertEqual(params.keys(), expected_keys)

        inputs = dummy_inputs(rng, config)
        hidden_states, pooled_output = encoder.apply({"params": params},
                                                     rngs={"dropout": rng},
                                                     **inputs)
        expected_hidden_states_shape = (config.train_batch_size,
                                        config.max_seq_length, config.d_model)
        self.assertEqual(hidden_states.shape, expected_hidden_states_shape)
        expected_pooled_output_shape = (config.train_batch_size,
                                        config.d_model)
        self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
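
The mixing sublayers here carry no learnable weights, which is why only the embedder, per-block wrappers, feed-forward sublayers, and pooler appear in params. As a rough illustration (a sketch of FNet-style Fourier mixing as described in the FNet paper, not the repo's implementation), a parameter-free mixing sublayer is a pure function of its input:

import jax.numpy as jnp

def fourier_mixing(hidden_states):
    """Parameter-free token mixing: the real part of a 2D FFT taken over the
    sequence and hidden dimensions, as in the FNet paper."""
    return jnp.real(jnp.fft.fft2(hidden_states))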
Code example #3
  def test_f_net_encoder_bad_long_seq(self):
    config = dummy_frozen_config(
        model_arch=base_config.ModelArchitecture.F_NET,
        use_tpu_fourier_optimizations=True,
        max_seq_length=8194)
    encoder = models.EncoderModel(config=config)
    rng = jax.random.PRNGKey(0)
    init_batch = init_encoder_batch(config)

    with self.assertRaisesRegex(
        ValueError,
        "must be a power of 2 to take advantage of FFT optimizations"):
      _ = init_model_params(rng, encoder, init_batch)
Code example #4
    def test_f_net_encoder_bad_long_seq(self):
        config = dummy_config(model_arch=ModelArchitecture.F_NET)
        with config.unlocked():
            config.max_seq_length = 8194
        frozen_config = ml_collections.FrozenConfigDict(config)

        encoder = models.EncoderModel(config=frozen_config)

        rng = jax.random.PRNGKey(0)
        init_batch = init_encoder_batch(config)

        with self.assertRaisesRegex(
                ValueError,
                "must be a power of 2 to take advantage of FFT optimizations"):
            _ = init_model_params(rng, encoder, init_batch)
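
Both variants of this test hinge on 8194 not being a power of 2, which is what the TPU Fourier optimizations require before the FFT fast path can be used. A quick standalone check (illustrative helper, not from the repo):

def is_power_of_two(n):
    return n > 0 and (n & (n - 1)) == 0

assert not is_power_of_two(8194)  # 8194 = 2 * 4097, so initialization raises ValueError
assert is_power_of_two(8192)      # 2**13 would be accepted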
Code example #5
  def test_bert_encoder(self):
    config = dummy_frozen_config(model_arch=base_config.ModelArchitecture.BERT)
    encoder = models.EncoderModel(config=config)
    rng = jax.random.PRNGKey(0)
    init_batch = init_encoder_batch(config)
    params = init_model_params(rng, encoder, init_batch)
    expected_keys = {
        "embedder", "encoder_0", "encoder_1", "feed_forward_0",
        "feed_forward_1", "self_attention_0", "self_attention_1", "pooler"
    }
    self.assertEqual(params.keys(), expected_keys)

    inputs = dummy_inputs(rng, config)
    hidden_states, pooled_output = encoder.apply({"params": params},
                                                 rngs={"dropout": rng},
                                                 **inputs)
    expected_hidden_states_shape = (config.train_batch_size,
                                    config.max_seq_length, config.d_model)
    self.assertEqual(hidden_states.shape, expected_hidden_states_shape)
    expected_pooled_output_shape = (config.train_batch_size, config.d_model)
    self.assertEqual(pooled_output.shape, expected_pooled_output_shape)
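
In both the BERT and FNet tests, pooled_output has shape (batch, d_model) because it is derived from the first-token representation of hidden_states. For intuition, a generic sketch of a BERT-style pooler (not the repo's code, and the layer name is an assumption):

import flax.linen as nn
import jax.numpy as jnp

class Pooler(nn.Module):
    d_model: int

    @nn.compact
    def __call__(self, hidden_states):
        # Project the first-token ([CLS]) representation: (batch, seq, d_model) -> (batch, d_model).
        first_token = hidden_states[:, 0, :]
        return jnp.tanh(nn.Dense(self.d_model)(first_token))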
Code example #6
def create_optimizer(key, config):
    """Creates optimizer for models.EncoderModel."""
    model = models.EncoderModel(config=config)

    init_batch = {
        "input_ids": jnp.ones((1, config.max_seq_length), jnp.int32),
        "input_mask": jnp.ones((1, config.max_seq_length), jnp.int32),
        "type_ids": jnp.ones((1, config.max_seq_length), jnp.int32)
    }

    key, dropout_key = jax.random.split(key)

    jit_init = jax.jit(model.init)
    initial_variables = jit_init({
        "params": key,
        "dropout": dropout_key
    }, **init_batch)
    params = initial_variables["params"]

    optimizer_def = optim.Adam(learning_rate=1e-4)
    return optimizer_def.create(params)
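
A hypothetical invocation of this helper, assuming a frozen config built as in the earlier test examples (the config construction below is an assumption, not part of the snippet):

config = ml_collections.FrozenConfigDict(dummy_config(model_arch=ModelArchitecture.F_NET))
rng = jax.random.PRNGKey(0)
optimizer = create_optimizer(rng, config)  # a flax.optim Optimizer wrapping the encoder params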
Code example #7
  def test_f_net_encoder_short_seq(self):
    config = dummy_frozen_config(
        model_arch=base_config.ModelArchitecture.F_NET, max_seq_length=16)
    encoder = models.EncoderModel(config=config)
    rng = jax.random.PRNGKey(0)
    init_batch = init_encoder_batch(config)
    params = init_model_params(rng, encoder, init_batch)
    # Fourier sublayers have no parameters so do not show up in params.
    expected_keys = {
        "embedder", "encoder_0", "encoder_1", "feed_forward_0",
        "feed_forward_1", "pooler"
    }
    self.assertEqual(params.keys(), expected_keys)

    inputs = dummy_inputs(rng, config)
    hidden_states, pooled_output = encoder.apply({"params": params},
                                                 rngs={"dropout": rng},
                                                 **inputs)
    expected_hidden_states_shape = (config.train_batch_size,
                                    config.max_seq_length, config.d_model)
    self.assertEqual(hidden_states.shape, expected_hidden_states_shape)
    expected_pooled_output_shape = (config.train_batch_size, config.d_model)
    self.assertEqual(pooled_output.shape, expected_pooled_output_shape)