Example #1
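Flax Linen style: the BERT encoder's submodules are declared in the module's setup method and used later in __call__ (not shown in this excerpt).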
    def setup(self):
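        # Token, position, and segment (token-type) embedding tables.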
        self.word_embeddings = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            embedding_init=get_kernel_init(self.config),
            name="word_embeddings",
        )
        self.position_embeddings = layers.PositionalEncoding(
            num_embeddings=self.config.max_position_embeddings,
            features=self.config.hidden_size,
            embedding_init=get_kernel_init(self.config),
            name="position_embeddings",
        )
        self.type_embeddings = nn.Embed(
            num_embeddings=self.config.type_vocab_size,
            features=self.config.hidden_size,
            embedding_init=get_kernel_init(self.config),
            name="type_embeddings",
        )
        self.embeddings_layer_norm = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, name="embeddings_layer_norm")
        self.embeddings_dropout = nn.Dropout(
            rate=self.config.hidden_dropout_prob)

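        # Shared constructors for the feed-forward and self-attention sub-layers,
        # reused by every encoder block below.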
        build_feed_forward = functools.partial(
            layers.FeedForward,
            d_model=self.config.hidden_size,
            d_ff=self.config.intermediate_size,
            intermediate_activation=get_hidden_activation(self.config),
            kernel_init=get_kernel_init(self.config),
        )
        build_self_attention = functools.partial(
            layers.SelfAttention,
            num_heads=self.config.num_attention_heads,
            qkv_features=self.config.hidden_size,
            dropout_rate=self.config.attention_probs_dropout_prob,
            broadcast_dropout=False,
            kernel_init=get_kernel_init(self.config),
            bias_init=nn.initializers.zeros,
        )
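        # Stack of num_hidden_layers Transformer encoder blocks.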
        self.encoder_layers = [
            layers.TransformerBlock(
                build_feed_forward=build_feed_forward,
                build_self_attention=build_self_attention,
                dropout_rate=self.config.hidden_dropout_prob,
                layer_norm_epsilon=self.config.layer_norm_eps,
                name=f"encoder_layer_{layer_num}",
            ) for layer_num in range(self.config.num_hidden_layers)
        ]
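        # Dense pooler; in BERT it is applied to the first ([CLS]) token's representation.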
        self.pooler = nn.Dense(
            kernel_init=get_kernel_init(self.config),
            name="pooler",
            features=self.config.hidden_size,
        )
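The forward pass is not part of this excerpt. A minimal sketch of what a matching __call__ might look like, mirroring the computation in Example #2; the call signatures of layers.PositionalEncoding and layers.TransformerBlock below are assumptions, and jnp stands for jax.numpy as in Example #2:

    def __call__(self, input_ids, input_mask, type_ids, *, deterministic=False):
        # Embed tokens, add positional and token-type embeddings (as in Example #2).
        word_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(word_embeddings)
        type_embeddings = self.type_embeddings(type_ids)
        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = self.embeddings_layer_norm(embeddings)
        embeddings = self.embeddings_dropout(embeddings, deterministic=deterministic)

        # Run the encoder stack over the embedded sequence.
        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
        for encoder_layer in self.encoder_layers:
            hidden_states = encoder_layer(
                hidden_states, mask, deterministic=deterministic)

        # Pool the first ([CLS]) token: dense projection followed by tanh.
        pooled_output = jnp.tanh(self.pooler(hidden_states[:, 0]))
        return hidden_states, pooled_output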
Example #2
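The same encoder in the older, pre-Linen flax.nn style: each layer is configured and applied inline inside the module's apply method.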
    def apply(self,
              input_ids,
              input_mask,
              type_ids,
              *,
              config,
              deterministic=False):
        """Applies BERT model on the inputs."""

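        # Token, position, and segment embeddings, summed and normalized below.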
        word_embeddings = nn.Embed(input_ids,
                                   num_embeddings=config.vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='word_embeddings')
        position_embeddings = layers.PositionalEncoding(
            word_embeddings,
            max_len=config.max_position_embeddings,
            posemb_init=get_kernel_init(config),
            name='position_embeddings')
        type_embeddings = nn.Embed(type_ids,
                                   num_embeddings=config.type_vocab_size,
                                   features=config.hidden_size,
                                   embedding_init=get_kernel_init(config),
                                   name='type_embeddings')

        embeddings = word_embeddings + position_embeddings + type_embeddings
        embeddings = nn.LayerNorm(embeddings,
                                  epsilon=LAYER_NORM_EPSILON,
                                  name='embeddings_layer_norm')
        embeddings = nn.dropout(embeddings,
                                rate=config.hidden_dropout_prob,
                                deterministic=deterministic)

        # Transformer blocks
        feed_forward = layers.FeedForward.partial(
            d_ff=config.intermediate_size,
            dropout_rate=config.hidden_dropout_prob,
            intermediate_activation=get_hidden_activation(config),
            kernel_init=get_kernel_init(config))

        attention = efficient_attention.BertSelfAttention.partial(
            num_heads=config.num_attention_heads,
            num_parallel_heads=None,
            d_qkv=config.hidden_size // config.num_attention_heads,
            attention_dropout_rate=config.attention_probs_dropout_prob,
            output_dropout_rate=config.hidden_dropout_prob,
            kernel_init=get_kernel_init(config),
            output_kernel_init=get_kernel_init(config))

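        # Apply the stack of Transformer encoder blocks to the embedded sequence.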
        hidden_states = embeddings
        mask = input_mask.astype(jnp.int32)
        for layer_num in range(config.num_hidden_layers):
            hidden_states = layers.TransformerBlock(
                hidden_states,
                mask,
                feed_forward=feed_forward,
                attention=attention,
                deterministic=deterministic,
                name=f'encoder_layer_{layer_num}')

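        # Pooled output: dense projection of the first ([CLS]) token followed by tanh.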
        pooled_output = nn.Dense(hidden_states[:, 0],
                                 config.hidden_size,
                                 kernel_init=get_kernel_init(config),
                                 name='pooler')
        pooled_output = jnp.tanh(pooled_output)

        return hidden_states, pooled_output
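Beyond the setup/apply split, note two smaller differences visible in the code: Example #1 shares sub-layer constructors via functools.partial and reads the layer-norm epsilon from the config (layer_norm_eps), while Example #2 uses the modules' own .partial(...) method and a module-level LAYER_NORM_EPSILON constant.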