def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):

        # Embed
        w_emb = FlaxRobertaEmbedding(
            self.vocab_size,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
            name="word_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(input_ids.astype("i4")))
        p_emb = FlaxRobertaEmbedding(
            self.max_length,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
            name="position_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(position_ids.astype("i4")))
        t_emb = FlaxRobertaEmbedding(
            self.type_vocab_size,
            self.hidden_size,
            kernel_init_scale=self.kernel_init_scale,
            name="token_type_embeddings",
            dtype=self.dtype,
        )(jnp.atleast_2d(token_type_ids.astype("i4")))

        # Sum all embeddings
        summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

        # Layer Norm
        layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
        embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
        return embeddings
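A minimal sketch (not taken from the snippet above) of how a Flax module that uses nn.Dropout is typically initialized and applied; the module name, shapes, and rates are illustrative assumptions. The key point is that dropout draws from its own "dropout" RNG stream passed via rngs when deterministic=False.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyDropoutBlock(nn.Module):  # hypothetical module, for illustration only
    features: int = 8
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic: bool = True):
        x = nn.Dense(self.features)(x)
        # Dropout is a no-op when deterministic=True (evaluation mode).
        return nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)

variables = TinyDropoutBlock().init(jax.random.PRNGKey(0), jnp.ones((2, 8)))
# At train time, dropout needs its own RNG stream:
out = TinyDropoutBlock().apply(variables, jnp.ones((2, 8)), deterministic=False,
                               rngs={"dropout": jax.random.PRNGKey(1)})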
Example #2
    def __call__(self, inputs, deterministic):
        """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      deterministic: if true, dropout is not applied; otherwise it is.

    Returns:
      output after transformer encoder block.
    """
        cfg = self.config

        # Attention block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
        x = nn.SelfAttention(num_heads=cfg.num_heads,
                             dtype=cfg.dtype,
                             qkv_features=cfg.qkv_dim,
                             kernel_init=cfg.kernel_init,
                             bias_init=cfg.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=cfg.attention_dropout_rate,
                             deterministic=deterministic)(x)

        x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=cfg.dtype)(x)
        y = MlpBlock(config=cfg)(y, deterministic=deterministic)
        return x + y
Example #3
    def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
        """Applies EncoderDecoder1DBlock module.

    Args:
      inputs: input data for decoder
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output after transformer encoder-decoder block.
    """
        config = self.config

        # Decoder block.
        assert inputs.ndim == 3
        x = nn.LayerNorm(dtype=config.dtype)(inputs)
        x = nn.SelfAttention(num_heads=config.num_heads,
                             dtype=config.dtype,
                             qkv_features=config.qkv_dim,
                             kernel_init=config.kernel_init,
                             bias_init=config.bias_init,
                             use_bias=False,
                             broadcast_dropout=False,
                             dropout_rate=config.attention_dropout_rate,
                             deterministic=config.deterministic,
                             decode=config.decode)(x, decoder_mask)
        x = nn.Dropout(rate=config.dropout_rate)(
            x, deterministic=config.deterministic)
        x = x + inputs

        # MLP block.
        z = nn.LayerNorm(dtype=config.dtype)(x)
        z = MlpBlock(config=config)(z)

        return x + z
 def setup(self):
     self.dense = nn.Dense(
         self.config.hidden_size,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         dtype=self.dtype,
     )
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads

        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
                "and `num_heads`: {self.num_heads}).")

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(self.config.init_std),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.causal:
            self.causal_mask = make_causal_mask(jnp.ones(
                (1, self.config.max_position_embeddings), dtype="bool"),
                                                dtype="bool")
    def __call__(self, inputs, is_training: bool):
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = SelfAttentionBlock(num_heads=self.num_heads,
                               head_ch=self.head_ch,
                               out_ch=self.num_heads * self.head_ch,
                               dropout_rate=self.attn_dropout_rate,
                               dtype=self.dtype,
                               precision=self.precision,
                               kernel_init=self.kernel_init)(
                                   x, is_training=is_training)
        x = nn.Dropout(rate=self.dropout_rate)(x,
                                               deterministic=not is_training)

        x += inputs

        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.expand_ratio,
                    dropout_rate=self.dropout_rate,
                    dtype=self.dtype,
                    precision=self.precision,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(y, train=is_training)

        output = x + y
        return output
Example #7
    def __call__(self, feat, lengths):
        """
        Compute Pooling by Multihead Attention.

        Parameters
        ----------
        feat : jnp.ndarray
            The input feature.
        lengths : list
            The array of node numbers, used to segment the feat tensor.

        Returns
        -------
        jnp.ndarray
            The output feature.
        """
        batch_size = len(lengths)
        query = self.seed_vectors.repeat(batch_size, axis=0)

        ffn_feat = nn.Dense(self.d_model)(
            nn.relu(
                nn.Dropout(self.dropouth)(
                    nn.Dense(self.d_ff)(feat)
                )
            )
        )

        return self.mha(query, ffn_feat, [self.k] * batch_size, lengths)
Example #8
    def __call__(self, *, inputs, train):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      train: whether the model is in training mode (enables dropout).

    Returns:
      output of a transformer decoder.

    """
        padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
        assert inputs.ndim == 2  # (batch, len)

        cfg = self.config

        x = inputs.astype('int32')
        x = nn.Embed(num_embeddings=cfg.vocab_size,
                     features=cfg.emb_dim,
                     name='embed')(x)
        x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=not train)
        x = AddPositionEmbs(cfg)(x)

        for _ in range(cfg.num_layers):
            x = Encoder1DBlock(cfg)(x, deterministic=not train)

        x = nn.LayerNorm(dtype=cfg.dtype)(x)
        logits = nn.Dense(cfg.output_vocab_size,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init)(x)
        return logits
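A hedged usage sketch for a module with keyword-only arguments like the one above: init and apply must pass inputs and train by name, and a "dropout" RNG is needed whenever train=True. The stand-in class, vocabulary size, and shapes below are assumptions, not the original model.

import jax
import jax.numpy as jnp
import flax.linen as nn

class KwOnlyModel(nn.Module):  # stand-in for the module above, illustrative only
    vocab_size: int = 100
    emb_dim: int = 16
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, *, inputs, train):
        x = nn.Embed(num_embeddings=self.vocab_size, features=self.emb_dim)(inputs)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)
        return nn.Dense(self.vocab_size)(x)

tokens = jnp.ones((2, 16), dtype=jnp.int32)
variables = KwOnlyModel().init(jax.random.PRNGKey(0), inputs=tokens, train=False)
# Keyword-only arguments are passed by name to apply as well:
logits = KwOnlyModel().apply(variables, inputs=tokens, train=True,
                             rngs={"dropout": jax.random.PRNGKey(1)})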
Example #9
    def setup(self):
        if self.config.hidden_size % self.config.num_attention_heads != 0:
            raise ValueError(
                "`config.hidden_size`: {self.config.hidden_size} has to be a multiple of `config.num_attention_heads`\
                    : {self.config.num_attention_heads}")

        self.query = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )
        self.key = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )
        self.value = nn.Dense(
            self.config.hidden_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )
        self.dense = nn.Dense(
            self.config.hidden_size,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            dtype=self.dtype,
        )
        self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,
                                      dtype=self.dtype)
        self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
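The setup above only defines the query/key/value/output projections. As a hedged sketch of how such projections are typically combined with scaled dot-product attention (shapes, scaling, and head split chosen here for illustration, not taken from the source):

import jax
import jax.numpy as jnp

batch, q_len, k_len, num_heads, head_dim = 2, 4, 4, 3, 8

# Outputs of the query/key/value Dense layers, reshaped to (batch, len, heads, head_dim):
q = jnp.ones((batch, q_len, num_heads, head_dim)) / jnp.sqrt(head_dim)
k = jnp.ones((batch, k_len, num_heads, head_dim))
v = jnp.ones((batch, k_len, num_heads, head_dim))

weights = jax.nn.softmax(jnp.einsum('...qhd,...khd->...hqk', q, k), axis=-1)
out = jnp.einsum('...hqk,...khd->...qhd', weights, v)
print(out.shape)  # (2, 4, 3, 8)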
 def setup(self):
     self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
     classifier_dropout = (self.config.classifier_dropout
                           if self.config.classifier_dropout is not None
                           else self.config.hidden_dropout_prob)
     self.dropout = nn.Dropout(classifier_dropout)
     self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
    def setup(self):
        config = self.config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.rotary_dim = config.rotary_dim

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
        )

        self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
        self.out_proj = dense()

        self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

        self.causal_mask = make_causal_mask(jnp.ones(
            (1, config.max_position_embeddings), dtype="bool"),
                                            dtype="bool")

        pos_embd_dim = self.rotary_dim or self.embed_dim
        self.embed_positions = create_sinusoidal_positions(
            config.max_position_embeddings, pos_embd_dim)
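For reference, a generic sinusoidal position table in the standard sin/cos formulation; this is only a sketch and may differ in layout from what create_sinusoidal_positions actually returns.

import numpy as np

def sinusoidal_positions(num_positions: int, dim: int) -> np.ndarray:
    # Standard formulation: sin half then cos half along the feature axis.
    inv_freq = 1.0 / (10000 ** (np.arange(0, dim, 2) / dim))
    angles = np.einsum("p,f->pf", np.arange(num_positions), inv_freq)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)

table = sinusoidal_positions(16, 8)  # shape (16, 8)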
Example #12
    def setup(self):
        self.n_heads = self.config.n_heads
        self.dim = self.config.dim
        self.dropout = nn.Dropout(rate=self.config.attention_dropout)

        if not (self.dim % self.n_heads == 0):
            raise ValueError(
                f"Hidden size {self.dim} not dividable by number of heads {self.n_heads}"
            )

        self.q_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range),
        )
        self.k_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range),
        )
        self.v_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range),
        )
        self.out_lin = nn.Dense(
            self.dim,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(
                stddev=self.config.initializer_range),
        )
 def setup(self):
     self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
     classifier_dropout = (self.config.classifier_dropout
                           if self.config.classifier_dropout is not None
                           else self.config.hidden_dropout_prob)
     self.dropout = nn.Dropout(classifier_dropout)
     self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)
Example #14
 def setup(self):
     self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     self.classifier = nn.Dense(
         self.config.num_labels,
         dtype=self.dtype,
     )
Example #15
 def setup(self):
     self.word_embeddings = nn.Embed(
         self.config.vocab_size,
         self.config.hidden_size,
         embedding_init=jax.nn.initializers.normal(
             stddev=self.config.initializer_range),
         dtype=self.dtype,
     )
     self.position_embeddings = nn.Embed(
         self.config.max_position_embeddings,
         self.config.hidden_size,
         embedding_init=jax.nn.initializers.normal(
             stddev=self.config.initializer_range),
         dtype=self.dtype,
     )
     self.token_type_embeddings = nn.Embed(
         self.config.type_vocab_size,
         self.config.hidden_size,
         embedding_init=jax.nn.initializers.normal(
             stddev=self.config.initializer_range),
         dtype=self.dtype,
     )
     self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,
                                   dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    def __call__(self, graph, train):
        dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

        graph = graph._replace(
            globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

        embedder = jraph.GraphMapFeatures(
            embed_node_fn=_make_embed(self.latent_dim),
            embed_edge_fn=_make_embed(self.latent_dim))
        graph = embedder(graph)

        for _ in range(self.num_message_passing_steps):
            net = jraph.GraphNetwork(update_edge_fn=_make_mlp(self.hidden_dims,
                                                              dropout=dropout),
                                     update_node_fn=_make_mlp(self.hidden_dims,
                                                              dropout=dropout),
                                     update_global_fn=_make_mlp(
                                         self.hidden_dims, dropout=dropout))

            graph = net(graph)

        # Map globals to represent the final result
        decoder = jraph.GraphMapFeatures(
            embed_global_fn=nn.Dense(self.num_outputs))
        graph = decoder(graph)

        return graph.globals
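A hedged sketch of the kind of input the graph model above consumes: a jraph.GraphsTuple whose globals field is overwritten inside the module. The node/edge counts and feature sizes below are invented for illustration.

import jax.numpy as jnp
import jraph

toy_graph = jraph.GraphsTuple(
    nodes=jnp.ones((3, 4)),          # 3 nodes, 4 features each
    edges=jnp.ones((2, 4)),          # 2 edges, 4 features each
    senders=jnp.array([0, 1]),
    receivers=jnp.array([1, 2]),
    globals=jnp.zeros((1, 1)),       # placeholder; replaced by the module above
    n_node=jnp.array([3]),
    n_edge=jnp.array([2]),
)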
    def __call__(self, inputs_q, inputs_kv, is_training: bool):
        assert inputs_q.ndim == inputs_kv.ndim == 3

        in_ch = inputs_q.shape[-1]
        assert in_ch % self.num_heads == 0
        head_ch = self.head_ch or int(in_ch / self.num_heads)
        out_ch = self.out_ch or in_ch

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(self.num_heads, head_ch),
                        use_bias=self.use_bias,
                        dtype=self.dtype)

        query = dense(name='queries')(inputs_q)
        key = dense(name='keys')(inputs_kv)
        value = dense(name='values')(inputs_kv)

        query = query / jnp.sqrt(head_ch)

        attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query,
                                  key)

        if self.talking_heads:
            attn_weights = TalkingHeadsBlock(
                num_heads=self.num_heads)(attn_weights)

        attn_weights = nn.softmax(attn_weights)

        if self.talking_heads:
            attn_weights = TalkingHeadsBlock(
                num_heads=self.num_heads)(attn_weights)

        attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
            attn_weights, deterministic=not is_training)

        attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d',
                                 attn_weights, value)

        output = nn.DenseGeneral(features=out_ch,
                                 axis=(-2, -1),
                                 use_bias=self.use_bias,
                                 dtype=self.dtype)(attn_scores)

        output = nn.Dropout(rate=self.out_dropout_rate)(
            output, deterministic=not is_training)
        return output
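TalkingHeadsBlock is not shown in this snippet; as a hedged guess at the technique it names (talking-heads attention), such a block mixes the attention weights across the head axis with a learned matrix, roughly as sketched below. This is an illustrative stand-in, not the actual implementation.

import flax.linen as nn
import jax.numpy as jnp

class TalkingHeadsSketch(nn.Module):  # illustrative stand-in only
    num_heads: int

    @nn.compact
    def __call__(self, attn_weights):  # (..., heads, q_len, k_len)
        mix = self.param('mix', nn.initializers.lecun_normal(),
                         (self.num_heads, self.num_heads))
        # Linearly combine the per-head attention maps ("talking heads").
        return jnp.einsum('...hqk,hg->...gqk', attn_weights, mix)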
    def __call__(self,
                 targets,
                 encoded,
                 decoder_mask=None,
                 encoder_decoder_mask=None):
        """Applies Transformer to decode the targets.

    Args:
      targets: target outputs.
      encoded: encoded input data from encoder [batch, ..., length, mlp_dim].
      decoder_mask: decoder self-attention mask
      encoder_decoder_mask: encoder-decoder attention mask

    Returns:
      output of a transformer decoder.
    """
        cfg = self.config

        assert encoded.ndim == targets.ndim + 1

        output_embed = nn.Embed(
            num_embeddings=cfg.output_vocab_size,
            features=cfg.emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0),
            name='embed_output')

        heads = dict()
        y = targets.astype('int32')
        if cfg.shift:
            y = shift_right(y, cfg.bos_token)

        y = output_embed(y)
        y = AddPositionEmbs(config=cfg,
                            cache=cfg.decode,
                            name='posembed_output')(y)
        y = nn.Dropout(rate=cfg.dropout_rate)(y,
                                              deterministic=cfg.deterministic)

        y = y.astype(cfg.dtype)
        # Target-Input Decoder
        for lyr in range(cfg.num_layers):
            y = EncoderDecoderBlock(config=cfg,
                                    name=f'encoderdecoderblock_{lyr}')(
                                        y, encoded, decoder_mask,
                                        encoder_decoder_mask)
        y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)

        heads['output_emb'] = y * (jnp.where(targets > 0, 1, 0).astype(
            jnp.float32)[Ellipsis, None])

        logits = nn.Dense(cfg.output_vocab_size,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init,
                          name='logitdense')(y)
        heads['logits'] = logits
        if cfg.output_head:
            return heads[cfg.output_head]
        else:
            return heads  # Return both output embeddings and logits.
Example #19
 def __call__(self, inputs, train):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
     x = nn.Dense(self.mlp_dim,
                  dtype=self.dtype,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(inputs)
     x = nn.relu(x)
     x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
     output = nn.Dense(actual_out_dim,
                       dtype=self.dtype,
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init)(x)
     output = nn.Dropout(rate=self.dropout_rate,
                         deterministic=not train)(output)
     return output
Example #20
    def setup(self):
        self.encoder = encoder_registry.get_registered_encoder(
            self.encoder_name)(**self.encoder_config)

        if self.apply_mlp:
            self.mlp = nn.Dense(self.encoder_config.hidden_size, dtype=self.dtype)
            self.dropout = nn.Dropout(self.encoder_config.dropout_rate)
        self.linear_classifier = nn.Dense(self.vocab_size, dtype=self.dtype)
Example #21
 def setup(self):
     self.dense = nn.Dense(
         self.config.hidden_size,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
         dtype=self.dtype,
     )
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
     self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
Example #22
 def setup(self):
     self.dropout_layer = nn.Dropout(rate=self.dropout_rate)
     self.keys_only_mlp_attention = KeysOnlyMlpAttention(
         hidden_size=self.hidden_size)
     self.mlp = MLP(hidden_size=self.hidden_size,
                    output_size=self.output_size,
                    output_bias=False,
                    dropout_rate=self.dropout_rate)
Example #23
 def setup(self):
     self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.final_dropout)
     self.lm_head = nn.Dense(
         self.config.vocab_size,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
         dtype=self.dtype,
     )
 def setup(self):
     self.cls_token = self.param("cls_token", nn.initializers.zeros, (1, 1, self.config.hidden_size))
     self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
     num_patches = self.patch_embeddings.num_patches
     self.position_embeddings = self.param(
         "position_embeddings", nn.initializers.zeros, (1, num_patches + 1, self.config.hidden_size)
     )
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
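A hedged sketch of how the cls_token and position_embeddings parameters defined above are commonly combined with patch embeddings in ViT-style models; the shapes below are illustrative, not taken from the source.

import jax.numpy as jnp

batch, num_patches, hidden = 2, 4, 8
patch_emb = jnp.ones((batch, num_patches, hidden))        # output of the patch embedder
cls_token = jnp.zeros((1, 1, hidden))
position_embeddings = jnp.zeros((1, num_patches + 1, hidden))

cls = jnp.broadcast_to(cls_token, (batch, 1, hidden))
x = jnp.concatenate([cls, patch_emb], axis=1) + position_embeddings  # (2, 5, 8)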
Example #25
 def __call__(self, inputs):
     x = inputs
     for size in self.feature_sizes:
         x = nn.Dense(features=size)(x)
         x = self.activation(x)
         x = nn.Dropout(rate=self.dropout_rate,
                        deterministic=self.deterministic)(x)
     return x
Example #26
 def __call__(self, inputs, deterministic=True):
     """Applies Transformer MlpBlock module."""
     cfg = self.config
     actual_out_dim = (inputs.shape[-1]
                       if self.out_dim is None else self.out_dim)
     x = nn.Dense(cfg.mlp_dim,
                  dtype=cfg.dtype,
                  kernel_init=cfg.kernel_init,
                  bias_init=cfg.bias_init)(inputs)
     x = nn.elu(x)
     x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
     output = nn.Dense(actual_out_dim,
                       dtype=cfg.dtype,
                       kernel_init=cfg.kernel_init,
                       bias_init=cfg.bias_init)(x)
     output = nn.Dropout(rate=cfg.dropout_rate)(output,
                                                deterministic=deterministic)
     return output
 def setup(self):
     self.roberta = FlaxRobertaModule(config=self.config,
                                      dtype=self.dtype,
                                      add_pooling_layer=False)
     classifier_dropout = (self.config.classifier_dropout
                           if self.config.classifier_dropout is not None
                           else self.config.hidden_dropout_prob)
     self.dropout = nn.Dropout(rate=classifier_dropout)
     self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)
Example #28
  def __call__(self, input_qkv):
    cfg = self.config
    assert cfg.max_len % cfg.max_seg_len == 0
    bsize = input_qkv.shape[0]
    features = self.out_features or input_qkv.shape[-1]
    num_seg = cfg.max_len // cfg.max_seg_len
    x_sqr = input_qkv.reshape([bsize, num_seg, cfg.max_seg_len, input_qkv.shape[-1]])
    q_row_local, key_row_local, value_row_local, head_dim = get_qkv(cfg, x_sqr)
    local_logits = jnp.einsum('...qhd,...khd->...qhk', q_row_local, key_row_local)
    row_probs = jax.nn.softmax(local_logits)
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')
      row_probs = dropatt(row_probs, dropout_rng, 1 - cfg.attention_dropout_rate)
    row_attn_out = jnp.einsum('...qhk,...khd->...qhd', row_probs, value_row_local)

    key_row = DenseGeneral(features=input_qkv.shape[-1],
                           axis=(-2, -1),
                           kernel_init=cfg.kernel_init,
                           bias_init=cfg.bias_init,
                           use_bias=False,
                           dtype=cfg.dtype)(row_attn_out)
    key_row = nn.Dropout(rate=cfg.dropout_rate)(key_row, deterministic=cfg.deterministic)
    key_row = key_row + x_sqr
    key_row = nn.LayerNorm(dtype=cfg.dtype)(key_row)
    key_row = DenseGeneral(axis=-1,
                           features=(cfg.num_heads, head_dim),
                           kernel_init=cfg.kernel_init,
                           bias_init=cfg.bias_init,
                           use_bias=False,
                           dtype=cfg.dtype)(key_row)
    idx_cols = jnp.arange(cfg.max_seg_len)
    local_mask = nn.make_attention_mask(idx_cols, idx_cols, jnp.less, extra_batch_dims=1)
    local_mask = jnp.expand_dims(local_mask, axis=-2) * -1e10
    local_logits = local_logits + local_mask

    global_logits = jnp.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
    idx_rows = jnp.arange(num_seg)
    global_mask = nn.make_attention_mask(idx_rows, idx_rows, jnp.less_equal)
    global_mask = global_mask[:, :, jnp.newaxis, jnp.newaxis, :] * -1e10
    global_logits = global_logits + global_mask

    joint_logits = jnp.concatenate((local_logits, global_logits), axis=-1)
    attn_probs = jax.nn.softmax(joint_logits, axis=-1)
    local_att, global_att = jnp.split(attn_probs, [cfg.max_seg_len], axis=-1)
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
      dropout_rng = self.make_rng('dropout')
      local_att = dropatt(local_att, dropout_rng, 1 - cfg.attention_dropout_rate)
    local_merged = jnp.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
    global_merged = jnp.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
    joint_merged = jnp.reshape(local_merged + global_merged, [bsize, cfg.max_len, cfg.num_heads, head_dim])
    x = DenseGeneral(features=features,
                  axis=(-2, -1),
                  kernel_init=cfg.kernel_init,
                  bias_init=cfg.bias_init,
                  use_bias=False,
                  dtype=cfg.dtype)(joint_merged)
    return x
Example #29
 def setup(self):
     self.mlp = nn.Dense(features=self.hidden_dim,
                         kernel_init=self.kernel_init,
                         bias_init=self.bias_init)
     self.dense = nn.Dense(features=self.input_dim,
                           kernel_init=self.kernel_init,
                           bias_init=self.bias_init)
     self.dropout = nn.Dropout(self.dropout_rate)
     self.layer_norm = nn.LayerNorm(epsilon=self.layer_norm_epsilon)
    def setup(self):
        embed_dim = self.config.hidden_size
        kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

        self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
        self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)

        self.act = ACT2FN[self.config.activation_function]
        self.dropout = nn.Dropout(rate=self.config.resid_pdrop)