Example 1
  def _init_mixing_sublayer(self, layer, model_arch,
                            mixing_key):
    """Initializes config-dependent mixing sublayer."""
    if model_arch == ModelArchitecture.BERT:
      mixing_sublayer = nn.SelfAttention(
          num_heads=self.config.num_heads,
          qkv_features=self.config.d_model,
          broadcast_dropout=False,
          kernel_init=default_kernel_init,
          bias_init=default_bias_init,
          dropout_rate=self.config.mixing_dropout_rate,
          use_bias=True,
          name=f"self_attention_{layer}")
    elif model_arch == ModelArchitecture.F_NET:
      mixing_sublayer = layers.FourierTransform(
          fourier_transform=self.fourier_transform,
          name=f"fourier_transform_{layer}")
    elif model_arch == ModelArchitecture.FF_ONLY:
      mixing_sublayer = layers.IdentityTransform(
          name=f"identity_transform_{layer}")
    elif model_arch == ModelArchitecture.LINEAR:
      mixing_sublayer = layers.LinearTransform(
          precision=lax.Precision.DEFAULT, name=f"linear_transform_{layer}")
    elif model_arch == ModelArchitecture.RANDOM:
      mixing_sublayer = layers.RandomTransform(
          max_seq_length=self.config.max_seq_length,
          d_model=self.config.d_model,
          key=mixing_key,
          precision=lax.Precision.DEFAULT,
          name=f"random_transform_{layer}")
    else:
      raise ValueError(f"Unexpected model architecture: {model_arch.name}")

    return mixing_sublayer
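
The branch on model_arch above is normally exercised once per encoder layer. The sketch below is not the repository's code: the encoder object and the ModelArchitecture value are assumed to come from the surrounding model module, and the helper name build_mixing_sublayers is hypothetical. It only illustrates the per-layer calling pattern, where the RANDOM architecture is the one branch that actually consumes its PRNG key.

import jax


def build_mixing_sublayers(encoder, num_layers, model_arch, rng):
  """Builds one mixing sublayer per encoder layer (hypothetical helper)."""
  # Split one key per layer; only ModelArchitecture.RANDOM samples from its
  # mixing_key, the remaining branches ignore it.
  keys = jax.random.split(rng, num_layers)
  return [
      encoder._init_mixing_sublayer(layer, model_arch, keys[layer])
      for layer in range(num_layers)
  ]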
Example 2
    def test_encoder_block(self):
        batch_size = 2
        max_seq_length = 14
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)

        feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.0)
        mixing_layer = layers.IdentityTransform()
        encoder_block = layers.EncoderBlock(
            feed_forward_sublayer=feed_forward_layer,
            mixing_sublayer=mixing_layer)
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32),
            "padding_mask": jnp.ones((1, max_seq_length), jnp.int32)
        }
        params = init_layer_variables(rng, encoder_block, init_batch)["params"]

        # The identity mixing sublayer creates no parameters, so only the two
        # layer norms and the feed-forward sublayer appear in params.
        expected_keys = {
            "mixing_layer_norm", "output_layer_norm", "feed_forward_sublayer"
        }
        self.assertEqual(params.keys(), expected_keys)

        rng, init_rng = jax.random.split(rng)
        inputs = {
            "inputs":
            jax.random.randint(init_rng,
                               (batch_size, max_seq_length, hidden_dim),
                               minval=0,
                               maxval=10),
            "padding_mask":
            jax.random.randint(init_rng, (batch_size, max_seq_length),
                               minval=0,
                               maxval=1)
        }

        outputs = encoder_block.apply({"params": params},
                                      rngs={"dropout": rng},
                                      **inputs)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
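
Both tests rely on an init_layer_variables helper that is not shown in the snippets. A minimal sketch of such a helper, assuming Flax linen modules whose __call__ accepts the batch entries as keyword arguments and uses the "params" and "dropout" RNG collections, could look like this:

import jax


def init_layer_variables(rng, module, init_batch):
  """Initializes a Flax linen module from a dummy batch (sketch)."""
  # Separate streams for parameter initialization and dropout, matching the
  # rng collections the tests pass to module.apply.
  params_rng, dropout_rng = jax.random.split(rng)
  return module.init(
      {"params": params_rng, "dropout": dropout_rng}, **init_batch)

With a helper of this shape, init_layer_variables(rng, encoder_block, init_batch)["params"] returns the params collection checked in the test above.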
Example 3
    def test_identity_transform(self):
        batch_size = 8
        max_seq_length = 8
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)

        identity_layer = layers.IdentityTransform()
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        init_layer_variables(rng, identity_layer, init_batch)

        rng, init_rng = jax.random.split(rng)
        inputs = jax.random.randint(init_rng,
                                    (batch_size, max_seq_length, hidden_dim),
                                    minval=0,
                                    maxval=10)
        # IdentityTransform layer has no learnable params.
        outputs = identity_layer.apply({"params": {}}, inputs=inputs)

        # Inputs are unchanged by IdentityTransform layer.
        np.testing.assert_allclose(outputs, inputs)
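
For reference, layers.IdentityTransform can be as small as a parameter-free Flax module that returns its inputs unchanged. The version below is a sketch consistent with the tests above, not necessarily the repository's exact signature.

import flax.linen as nn


class IdentityTransform(nn.Module):
  """Parameter-free mixing sublayer that returns its inputs unchanged (sketch)."""

  @nn.compact
  def __call__(self, inputs):
    # No learnable parameters are created, so the "params" collection is empty.
    return inputs

Because nothing is created during initialization, the variables returned by init are empty and the test can call apply with an empty params dict, as shown above.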