Code Example #1
    def test_feed_forward_layer(self):
        batch_size = 3
        max_seq_length = 16
        hidden_dim = 12
        rng = jax.random.PRNGKey(0)

        # Build the feed-forward sublayer and initialize its parameters
        # from a dummy all-ones batch.
        feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.1)
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32)
        }
        params = init_layer_variables(rng, feed_forward_layer,
                                      init_batch)["params"]

        expected_keys = {"intermediate", "output"}
        self.assertEqual(params.keys(), expected_keys)

        rng, init_rng = jax.random.split(rng)
        # Random integer activations suffice for a shape check; they are
        # promoted to float inside the dense layers.
        inputs = jax.random.randint(init_rng,
                                    (batch_size, max_seq_length, hidden_dim),
                                    minval=0,
                                    maxval=10)
        outputs = feed_forward_layer.apply({"params": params},
                                           rngs={"dropout": rng},
                                           inputs=inputs)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
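
The helper init_layer_variables is defined elsewhere in the test module. A minimal sketch of what such a helper plausibly looks like (an assumption for illustration, not the project's actual code): split the key into separate parameter and dropout streams and call Flax's Module.init.

def init_layer_variables(rng, module, init_batch):
    # Hypothetical helper: give Flax separate RNG streams for parameter
    # initialization and dropout, then return the variable collections.
    params_rng, dropout_rng = jax.random.split(rng)
    return module.init(
        {"params": params_rng, "dropout": dropout_rng}, **init_batch)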
Code Example #2
File: models.py  Project: lucifer2288/google-research
  def setup(self):
    """Initializes encoder with config-dependent mixing layer."""
    if self.config.model_arch == ModelArchitecture.F_NET:
      self._init_fourier_transform()

    # Random number generator key for RANDOM model architecture.
    key = random.PRNGKey(self.random_seed)

    encoder_blocks = []  # Attributes are immutable so use temporary list
    for layer in range(self.config.num_layers):
      key, mixing_key = random.split(key)
      mixing_arch = (
          ModelArchitecture.BERT
          if self._is_attention_layer(layer) else self.config.model_arch)
      mixing_layer = self._init_mixing_sublayer(layer, mixing_arch, mixing_key)
      feed_forward_layer = layers.FeedForwardLayer(
          d_ff=self.config.d_ff,
          dropout_rate=self.config.dropout_rate,
          name=f"feed_forward_{layer}")
      encoder_blocks.append(
          layers.EncoderBlock(
              mixing_sublayer=mixing_layer,
              feed_forward_sublayer=feed_forward_layer,
              name=f"encoder_{layer}"))
    self.encoder_blocks = encoder_blocks

    self.embedder = layers.EmbeddingLayer(config=self.config, name="embedder")

    self.pooler = nn.Dense(
        self.config.d_model, kernel_init=default_kernel_init, name="pooler")
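
setup() only wires the submodules together; the forward pass lives in the model's __call__. A hedged sketch of how that call might thread inputs through the blocks built above (the embedder and pooling signatures here are assumptions, not the project's API):

  def __call__(self, input_ids, padding_mask=None, deterministic=False):
    # Hypothetical forward pass matching the modules created in setup().
    x = self.embedder(input_ids, deterministic=deterministic)
    for encoder_block in self.encoder_blocks:
      x = encoder_block(inputs=x, padding_mask=padding_mask)
    # Pool the first position's representation for sequence-level heads.
    pooled = jnp.tanh(self.pooler(x[:, 0]))
    return x, pooled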
Code Example #3
    def test_encoder_block(self):
        batch_size = 2
        max_seq_length = 14
        hidden_dim = 8
        rng = jax.random.PRNGKey(0)

        feed_forward_layer = layers.FeedForwardLayer(d_ff=8, dropout_rate=0.0)
        mixing_layer = layers.IdentityTransform()
        encoder_block = layers.EncoderBlock(
            feed_forward_sublayer=feed_forward_layer,
            mixing_sublayer=mixing_layer)
        init_batch = {
            "inputs": jnp.ones((1, max_seq_length, hidden_dim), jnp.float32),
            "padding_mask": jnp.ones((1, max_seq_length), jnp.int32)
        }
        params = init_layer_variables(rng, encoder_block, init_batch)["params"]

        # IdentityTransform adds no parameters, so only the two layer norms
        # and the feed-forward sublayer appear in the parameter tree.
        expected_keys = {
            "mixing_layer_norm", "output_layer_norm", "feed_forward_sublayer"
        }
        self.assertEqual(params.keys(), expected_keys)

        rng, init_rng = jax.random.split(rng)
        inputs = {
            "inputs": jax.random.randint(
                init_rng, (batch_size, max_seq_length, hidden_dim),
                minval=0, maxval=10),
            # randint's maxval is exclusive, so maxval=2 produces a mix of
            # 0s and 1s for the padding mask.
            "padding_mask": jax.random.randint(
                init_rng, (batch_size, max_seq_length),
                minval=0, maxval=2)
        }

        outputs = encoder_block.apply({"params": params},
                                      rngs={"dropout": rng},
                                      **inputs)
        self.assertEqual(outputs.shape,
                         (batch_size, max_seq_length, hidden_dim))
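
layers.IdentityTransform acts as a no-op mixing sublayer here, so the test exercises only the block's layer norms and feed-forward path. A minimal sketch of such a module (an assumed shape for illustration, not the project's source):

import flax.linen as nn

class IdentityTransform(nn.Module):
    """Pass-through mixing sublayer: returns its inputs unchanged."""

    def __call__(self, inputs):
        return inputs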