Example #1
  def _init_mixing_sublayer(self, layer, model_arch,
                            mixing_key):
    """Initializes config-dependent mixing sublayer."""
    if model_arch == ModelArchitecture.BERT:
      # Standard Transformer mixing: multi-headed self-attention.
      mixing_sublayer = nn.SelfAttention(
          num_heads=self.config.num_heads,
          qkv_features=self.config.d_model,
          broadcast_dropout=False,
          kernel_init=default_kernel_init,
          bias_init=default_bias_init,
          dropout_rate=self.config.mixing_dropout_rate,
          use_bias=True,
          name=f"self_attention_{layer}")
    elif model_arch == ModelArchitecture.F_NET:
      # FNet mixing: parameter-free Fourier sublayer; the concrete
      # transform is injected via self.fourier_transform.
      mixing_sublayer = layers.FourierTransform(
          fourier_transform=self.fourier_transform,
          name=f"fourier_transform_{layer}")
    elif model_arch == ModelArchitecture.FF_ONLY:
      # Ablation baseline: no token mixing at all.
      mixing_sublayer = layers.IdentityTransform(
          name=f"identity_transform_{layer}")
    elif model_arch == ModelArchitecture.LINEAR:
      # Learned linear mixing along the sequence and hidden dimensions.
      mixing_sublayer = layers.LinearTransform(
          precision=lax.Precision.DEFAULT, name=f"linear_transform_{layer}")
    elif model_arch == ModelArchitecture.RANDOM:
      # Fixed random mixing matrices, sampled once from mixing_key.
      mixing_sublayer = layers.RandomTransform(
          max_seq_length=self.config.max_seq_length,
          d_model=self.config.d_model,
          key=mixing_key,
          precision=lax.Precision.DEFAULT,
          name=f"random_transform_{layer}")
    else:
      raise ValueError(f"Unexpected model architecture: {model_arch.name}")

    return mixing_sublayer
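
The F_NET branch above swaps self-attention for a parameter-free Fourier mixing sublayer. For reference, the FNet recipe mixes tokens by keeping the real part of a 2D discrete Fourier transform applied over the sequence and hidden dimensions. The standalone sketch below illustrates that idea; fourier_mix is illustrative only and is not the repo's layers.FourierTransform, which instead receives its concrete transform through the fourier_transform argument.

import jax.numpy as jnp

def fourier_mix(x):
  """Parameter-free token mixing: real part of a 2D DFT over (seq, hidden).

  x has shape (batch, seq_len, d_model); the output shape is unchanged.
  """
  return jnp.fft.fft2(x, axes=(-2, -1)).real

x = jnp.ones((2, 16, 32), jnp.float32)
assert fourier_mix(x).shape == (2, 16, 32)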
Example #2
    def test_linear_transform(self):
        batch_size = 8
        max_seq_length = 16
        hidden_dim = 32
        rng = jax.random.PRNGKey(0)

        # Initialize the layer with a dummy batch to create its parameters.
        linear_layer = layers.LinearTransform()
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        params = init_layer_variables(rng, linear_layer, init_batch)["params"]

        # One learned mixing kernel per mixed dimension (hidden and sequence).
        expected_keys = {"hidden_kernel", "seq_kernel"}
        self.assertEqual(params.keys(), expected_keys)

        # The mixing sublayer must preserve the input shape.
        rng, data_rng = jax.random.split(rng)
        inputs = jax.random.randint(data_rng,
                                    (batch_size, max_seq_length, hidden_dim),
                                    minval=0,
                                    maxval=13)
        outputs = linear_layer.apply({"params": params}, inputs=inputs)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
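
The test relies on an init_layer_variables helper that is not shown in this snippet. A plausible minimal sketch, assuming the layer is a standard Flax linen module: the helper threads PRNG keys and the example batch through Module.init and returns the resulting variable collections. The name and signature below mirror the call site; the real utility may differ.

import jax

def init_layer_variables(rng, layer, init_batch):
  """Hypothetical helper mirroring the test's call site; a sketch, not the repo's code."""
  param_rng, dropout_rng = jax.random.split(rng)
  return layer.init({"params": param_rng, "dropout": dropout_rng}, **init_batch)

With such a helper, the returned params contain hidden_kernel and seq_kernel, the two learned matrices the LINEAR architecture uses to mix along the hidden and sequence dimensions respectively.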