Example 1
    def test_transformer_block_shape(self):
        """Testing transformer block shape."""

        encoding = jnp.ones(shape=(self.bsz, self.seq_len, self.model_dim),
                            dtype=self.dtype)

        attention_mask = jnp.ones(shape=(self.bsz, self.seq_len),
                                  dtype=self.dtype)

        model = transformer.TransformerBlock(
            num_layers=self.num_layers,
            model_dim=self.model_dim,
            intermediate_dim=self.intermediate_dim,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
        )

        rng = jax.random.PRNGKey(0)
        output, _ = model.init_with_output(
            rng,
            encoding=encoding,
            attention_mask=attention_mask,
            deterministic=True,
        )

        self.assertSequenceEqual(output.shape, encoding.shape)
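
The test above reads its shapes and hyperparameters from attributes set in `setUp`, which the snippet does not include. A minimal sketch of such a fixture is shown below, using a plain `unittest.TestCase` base; the concrete values and the `transformer` import are assumptions, not taken from the original test.

import unittest

import jax
import jax.numpy as jnp

# `transformer` (providing TransformerBlock) is assumed to be importable from
# this codebase; the exact module path depends on the project layout.


class TransformerBlockShapeTest(unittest.TestCase):

    def setUp(self):
        super().setUp()
        # Illustrative values only; the original test's configuration is not
        # shown. model_dim is kept divisible by num_heads.
        self.bsz = 2
        self.seq_len = 8
        self.model_dim = 16
        self.intermediate_dim = 32
        self.num_layers = 2
        self.num_heads = 4
        self.dropout_rate = 0.1
        self.dtype = jnp.float32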
Example 2
def make_transformer_block(n_layers: int):
  # Helper that builds a TransformerBlock of the given depth, reusing the
  # shared configuration attributes of the enclosing module (`self` is taken
  # from the enclosing scope).
  return transformer.TransformerBlock(
      num_layers=n_layers,
      model_dim=self.hidden_size,
      intermediate_dim=self.intermediate_dim,
      num_heads=self.num_attention_heads,
      dropout_rate=self.dropout_rate,
      dtype=self.dtype,
      kernel_init=self.kernel_init,
      bias_init=self.bias_init,
      layer_norm_epsilon=self.layer_norm_epsilon,
  )
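
Since the helper closes over `self`, it presumably lives inside a module method such as `setup()`. A hedged sketch of how it might be called there, reusing the attribute names that appear in Example 4 below (the enclosing context is not shown in the original):

    # Build two stacks of different depth without repeating the shared
    # TransformerBlock configuration.
    self.initial_encoder = make_transformer_block(self.num_initial_layers)
    self.final_encoder = make_transformer_block(self.num_final_layers)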
Example 3
    def setup(self):

        self.embedder = embedding.DictEmbed({
            'token_ids':
            embedding.Embed(
                num_embeddings=self.vocab_size,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            ),
            'position_ids':
            embedding.Embed(
                num_embeddings=self.max_positions,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            ),
            'segment_ids':
            embedding.Embed(
                num_embeddings=self.num_segments,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            )
        })

        self.embeddings_layer_norm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon)
        self.embeddings_dropout = nn.Dropout(rate=self.dropout_rate)

        self.encoder = transformer.TransformerBlock(
            num_layers=self.num_layers,
            model_dim=self.hidden_size,
            intermediate_dim=self.intermediate_dim,
            num_heads=self.num_attention_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            layer_norm_epsilon=self.layer_norm_epsilon,
        )

        self.mention_projector = nn.Dense(
            features=self.mention_encoding_dim,
            dtype=self.dtype,
        )
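
Example 3 only defines the submodules. A hedged sketch of how a `__call__` method might wire them together follows: the `TransformerBlock` call signature mirrors Example 1, while the `batch` keys, the `DictEmbed` call convention (summing the per-key embeddings), and the mention-gathering step are assumptions, not taken from the original module.

    def __call__(self, batch, deterministic: bool):
        # Embed token, position and segment ids (assuming DictEmbed accepts a
        # dict keyed like its sub-embedders and returns the summed embeddings).
        encoding = self.embedder({
            'token_ids': batch['token_ids'],
            'position_ids': batch['position_ids'],
            'segment_ids': batch['segment_ids'],
        })
        encoding = self.embeddings_layer_norm(encoding)
        encoding = self.embeddings_dropout(
            encoding, deterministic=deterministic)

        # Contextualize with the transformer stack (same call signature as in
        # Example 1).
        encoding = self.encoder(
            encoding=encoding,
            attention_mask=batch['attention_mask'],
            deterministic=deterministic,
        )

        # Gather hidden states at (assumed) mention start positions and project
        # them to mention_encoding_dim.
        batch_indices = jnp.arange(encoding.shape[0])[:, None]
        mention_encodings = self.mention_projector(
            encoding[batch_indices, batch['mention_start_positions']])
        return encoding, mention_encodings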
Example 4
    def setup(self):

        self.embedder = embedding.DictEmbed({
            'token_ids':
            embedding.Embed(
                num_embeddings=self.vocab_size,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            ),
            'position_ids':
            embedding.Embed(
                num_embeddings=self.max_positions,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            ),
            'segment_ids':
            embedding.Embed(
                num_embeddings=self.num_segments,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            )
        })

        self.embeddings_layer_norm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon)
        self.embeddings_dropout = nn.Dropout(rate=self.dropout_rate)

        self.initial_encoder = transformer.TransformerBlock(
            num_layers=self.num_initial_layers,
            model_dim=self.hidden_size,
            intermediate_dim=self.intermediate_dim,
            num_heads=self.num_attention_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            layer_norm_epsilon=self.layer_norm_epsilon,
        )

        self.retrieval_update_layer = retrieval_update_layers.RETRIEVAL_UPDATE_REGISTRY[
            self.retrieval_update_type](
                input_dim=self.hidden_size,
                dtype=self.dtype,
                layer_norm_epsilon=self.layer_norm_epsilon,
                **self.retrieval_update_config,
            )

        self.final_encoder = transformer.TransformerBlock(
            num_layers=self.num_final_layers,
            model_dim=self.hidden_size,
            intermediate_dim=self.intermediate_dim,
            num_heads=self.num_attention_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            layer_norm_epsilon=self.layer_norm_epsilon,
        )

        self.mention_projector = nn.Dense(
            features=self.retrieval_dim,
            dtype=self.dtype,
        )