def __call__(self, input_ids, token_type_ids, position_ids, attention_mask, deterministic: bool = True):
    # Embed
    w_emb = FlaxRobertaEmbedding(
        self.vocab_size,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="word_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(input_ids.astype("i4")))
    p_emb = FlaxRobertaEmbedding(
        self.max_length,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="position_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(position_ids.astype("i4")))
    t_emb = FlaxRobertaEmbedding(
        self.type_vocab_size,
        self.hidden_size,
        kernel_init_scale=self.kernel_init_scale,
        name="token_type_embeddings",
        dtype=self.dtype,
    )(jnp.atleast_2d(token_type_ids.astype("i4")))

    # Sum all embeddings
    summed_emb = w_emb + jnp.broadcast_to(p_emb, w_emb.shape) + t_emb

    # Layer Norm
    layer_norm = FlaxRobertaLayerNorm(name="layer_norm", dtype=self.dtype)(summed_emb)
    embeddings = nn.Dropout(rate=self.dropout_rate)(layer_norm, deterministic=deterministic)
    return embeddings

def __call__(self, inputs, deterministic):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      deterministic: if true, dropout is not applied; otherwise it is.

    Returns:
      output after transformer encoder block.
    """
    cfg = self.config

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
    x = nn.SelfAttention(num_heads=cfg.num_heads,
                         dtype=cfg.dtype,
                         qkv_features=cfg.qkv_dim,
                         kernel_init=cfg.kernel_init,
                         bias_init=cfg.bias_init,
                         use_bias=False,
                         broadcast_dropout=False,
                         dropout_rate=cfg.attention_dropout_rate,
                         deterministic=deterministic)(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = MlpBlock(config=cfg)(y, deterministic=deterministic)

    return x + y

def __call__(self, inputs, decoder_mask=None, encoder_decoder_mask=None):
    """Applies EncoderDecoder1DBlock module.

    Args:
      inputs: input data for decoder.
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output after transformer encoder-decoder block.
    """
    config = self.config

    # Decoder block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(dtype=config.dtype)(inputs)
    x = nn.SelfAttention(num_heads=config.num_heads,
                         dtype=config.dtype,
                         qkv_features=config.qkv_dim,
                         kernel_init=config.kernel_init,
                         bias_init=config.bias_init,
                         use_bias=False,
                         broadcast_dropout=False,
                         dropout_rate=config.attention_dropout_rate,
                         deterministic=config.deterministic,
                         decode=config.decode)(x, decoder_mask)
    x = nn.Dropout(rate=config.dropout_rate)(x, deterministic=config.deterministic)
    x = x + inputs

    # MLP block.
    z = nn.LayerNorm(dtype=config.dtype)(x)
    z = MlpBlock(config=config)(z)

    return x + z

def setup(self):
    self.dense = nn.Dense(
        self.config.hidden_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        dtype=self.dtype,
    )
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def setup(self) -> None:
    self.head_dim = self.embed_dim // self.num_heads
    if self.head_dim * self.num_heads != self.embed_dim:
        raise ValueError(
            f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} "
            f"and `num_heads`: {self.num_heads}).")

    dense = partial(
        nn.Dense,
        self.embed_dim,
        use_bias=self.bias,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.init_std),
    )

    self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
    self.out_proj = dense()

    self.dropout_layer = nn.Dropout(rate=self.dropout)

    if self.causal:
        self.causal_mask = make_causal_mask(
            jnp.ones((1, self.config.max_position_embeddings), dtype="bool"),
            dtype="bool")

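# Hedged usage sketch, not from the original source: illustrates the shape of the
# mask built by make_causal_mask in the setup above (flax.linen.make_causal_mask);
# the sequence length 8 is illustrative.
import jax.numpy as jnp
from flax.linen import make_causal_mask

causal_mask = make_causal_mask(jnp.ones((1, 8), dtype="bool"), dtype="bool")
# Shape (1, 1, 8, 8): batch dim, broadcastable head dim, query length, key length.
# Entry [..., i, j] is True only when j <= i, so each position attends to itself
# and to earlier positions.
print(causal_mask.shape)
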
def __call__(self, inputs, is_training: bool):
    x = nn.LayerNorm(dtype=self.dtype)(inputs)
    x = SelfAttentionBlock(num_heads=self.num_heads,
                           head_ch=self.head_ch,
                           out_ch=self.num_heads * self.head_ch,
                           dropout_rate=self.attn_dropout_rate,
                           dtype=self.dtype,
                           precision=self.precision,
                           kernel_init=self.kernel_init)(x, is_training=is_training)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)
    x += inputs

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(expand_ratio=self.expand_ratio,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
                precision=self.precision,
                kernel_init=self.kernel_init,
                bias_init=self.bias_init)(y, train=is_training)
    output = x + y
    return output

def __call__(self, feat, lengths):
    """Compute Pooling by Multihead Attention.

    Parameters
    ----------
    feat : jnp.ndarray
        The input feature.
    lengths : list
        The array of node numbers, used to segment the feat tensor.

    Returns
    -------
    jnp.ndarray
        The output feature.
    """
    batch_size = len(lengths)
    query = self.seed_vectors.repeat(batch_size, axis=0)
    ffn_feat = nn.Dense(self.d_model)(
        nn.relu(
            nn.Dropout(self.dropouth)(
                nn.Dense(self.d_ff)(feat)
            )
        )
    )
    return self.mha(query, ffn_feat, [self.k] * batch_size, lengths)

def __call__(self, *, inputs, train):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data.
      train: whether the model is in training mode (enables dropout).

    Returns:
      output of a transformer decoder.
    """
    padding_mask = jnp.where(inputs > 0, 1, 0).astype(jnp.float32)[..., None]
    assert inputs.ndim == 2  # (batch, len)
    cfg = self.config

    x = inputs.astype('int32')
    x = nn.Embed(num_embeddings=cfg.vocab_size, features=cfg.emb_dim, name='embed')(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=not train)
    x = AddPositionEmbs(cfg)(x)

    for l in range(cfg.num_layers):
        x = Encoder1DBlock(cfg)(x, deterministic=not train)

    x = nn.LayerNorm(dtype=cfg.dtype)(x)
    logits = nn.Dense(cfg.output_vocab_size,
                      kernel_init=cfg.kernel_init,
                      bias_init=cfg.bias_init)(x)
    return logits

def setup(self):
    if self.config.hidden_size % self.config.num_attention_heads != 0:
        raise ValueError(
            f"`config.hidden_size`: {self.config.hidden_size} has to be a multiple of "
            f"`config.num_attention_heads`: {self.config.num_attention_heads}")

    self.query = nn.Dense(
        self.config.hidden_size,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )
    self.key = nn.Dense(
        self.config.hidden_size,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )
    self.value = nn.Dense(
        self.config.hidden_size,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )

    self.dense = nn.Dense(
        self.config.hidden_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        dtype=self.dtype,
    )
    self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

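# Illustrative sketch (an assumption, not the original __call__): the projections
# declared above are typically reshaped from (batch, seq, hidden) into
# (batch, seq, num_heads, head_dim) before computing attention scores.
import jax.numpy as jnp

def split_heads(hidden_states, num_heads):
    batch, seq_len, hidden = hidden_states.shape
    # head_dim = hidden // num_heads; the check in setup guarantees the division is exact.
    return hidden_states.reshape(batch, seq_len, num_heads, hidden // num_heads)

# Example: 12 heads over a 768-dim hidden state -> head_dim 64.
x = jnp.ones((2, 16, 768))
assert split_heads(x, 12).shape == (2, 16, 12, 64)
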
def setup(self):
    self.electra = FlaxElectraModule(config=self.config, dtype=self.dtype)
    classifier_dropout = (
        self.config.classifier_dropout
        if self.config.classifier_dropout is not None
        else self.config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(rate=classifier_dropout)
    self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

def setup(self):
    config = self.config
    self.embed_dim = config.hidden_size
    self.num_heads = config.num_attention_heads
    self.head_dim = self.embed_dim // self.num_heads
    self.rotary_dim = config.rotary_dim

    dense = partial(
        nn.Dense,
        self.embed_dim,
        use_bias=False,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )

    self.q_proj, self.k_proj, self.v_proj = dense(), dense(), dense()
    self.out_proj = dense()

    self.resid_dropout = nn.Dropout(rate=config.resid_pdrop)

    self.causal_mask = make_causal_mask(
        jnp.ones((1, config.max_position_embeddings), dtype="bool"), dtype="bool")

    pos_embd_dim = self.rotary_dim or self.embed_dim
    self.embed_positions = create_sinusoidal_positions(
        config.max_position_embeddings, pos_embd_dim)

def setup(self):
    self.n_heads = self.config.n_heads
    self.dim = self.config.dim
    self.dropout = nn.Dropout(rate=self.config.attention_dropout)

    if self.dim % self.n_heads != 0:
        raise ValueError(
            f"Hidden size {self.dim} is not divisible by the number of heads {self.n_heads}")

    self.q_lin = nn.Dense(
        self.dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.k_lin = nn.Dense(
        self.dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.v_lin = nn.Dense(
        self.dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.out_lin = nn.Dense(
        self.dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )

def setup(self):
    self.dense = nn.Dense(self.config.hidden_size, dtype=self.dtype)
    classifier_dropout = (
        self.config.classifier_dropout
        if self.config.classifier_dropout is not None
        else self.config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(rate=classifier_dropout)
    self.out_proj = nn.Dense(self.config.num_labels, dtype=self.dtype)

def setup(self):
    self.bert = FlaxBertModule(config=self.config, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    self.classifier = nn.Dense(
        self.config.num_labels,
        dtype=self.dtype,
    )

def setup(self):
    self.word_embeddings = nn.Embed(
        self.config.vocab_size,
        self.config.hidden_size,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        dtype=self.dtype,
    )
    self.position_embeddings = nn.Embed(
        self.config.max_position_embeddings,
        self.config.hidden_size,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        dtype=self.dtype,
    )
    self.token_type_embeddings = nn.Embed(
        self.config.type_vocab_size,
        self.config.hidden_size,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        dtype=self.dtype,
    )
    self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, graph, train):
    dropout = nn.Dropout(rate=self.dropout_rate, deterministic=not train)

    graph = graph._replace(
        globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]))

    embedder = jraph.GraphMapFeatures(
        embed_node_fn=_make_embed(self.latent_dim),
        embed_edge_fn=_make_embed(self.latent_dim))
    graph = embedder(graph)

    for _ in range(self.num_message_passing_steps):
        net = jraph.GraphNetwork(
            update_edge_fn=_make_mlp(self.hidden_dims, dropout=dropout),
            update_node_fn=_make_mlp(self.hidden_dims, dropout=dropout),
            update_global_fn=_make_mlp(self.hidden_dims, dropout=dropout))
        graph = net(graph)

    # Map globals to represent the final result.
    decoder = jraph.GraphMapFeatures(embed_global_fn=nn.Dense(self.num_outputs))
    graph = decoder(graph)

    return graph.globals

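# Hedged sketch (not part of the original module) of the jraph.GraphsTuple layout
# the __call__ above expects; node/edge counts and feature sizes are illustrative.
import jax.numpy as jnp
import jraph

graph = jraph.GraphsTuple(
    nodes=jnp.ones((4, 8)),       # 4 nodes with 8 features each
    edges=jnp.ones((3, 8)),       # 3 edges with 8 features each
    senders=jnp.array([0, 1, 2]),
    receivers=jnp.array([1, 2, 3]),
    n_node=jnp.array([4]),        # a single graph in the batch
    n_edge=jnp.array([3]),
    globals=None,                 # the module overwrites globals via graph._replace
)
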
def __call__(self, inputs_q, inputs_kv, is_training: bool):
    assert inputs_q.ndim == inputs_kv.ndim == 3

    in_ch = inputs_q.shape[-1]
    assert in_ch % self.num_heads == 0
    head_ch = self.head_ch or int(in_ch / self.num_heads)
    out_ch = self.out_ch or in_ch

    dense = partial(nn.DenseGeneral,
                    axis=-1,
                    features=(self.num_heads, head_ch),
                    use_bias=self.use_bias,
                    dtype=self.dtype)

    query = dense(name='queries')(inputs_q)
    key = dense(name='keys')(inputs_kv)
    value = dense(name='values')(inputs_kv)

    query = query / jnp.sqrt(head_ch)

    attn_weights = jnp.einsum('... q h d, ... k h d -> ... h q k', query, key)

    if self.talking_heads:
        attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights)

    attn_weights = nn.softmax(attn_weights)

    if self.talking_heads:
        attn_weights = TalkingHeadsBlock(num_heads=self.num_heads)(attn_weights)

    attn_weights = nn.Dropout(rate=self.attn_dropout_rate)(
        attn_weights, deterministic=not is_training)

    attn_scores = jnp.einsum('... h q k, ... k h d -> ... q h d', attn_weights, value)

    output = nn.DenseGeneral(features=out_ch,
                             axis=(-2, -1),
                             use_bias=self.use_bias,
                             dtype=self.dtype)(attn_scores)

    output = nn.Dropout(rate=self.out_dropout_rate)(
        output, deterministic=not is_training)

    return output

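# Standalone sketch of the two einsum contractions used above
# ('...qhd,...khd->...hqk' for scores, '...hqk,...khd->...qhd' for the weighted
# sum); shapes are illustrative, and the dropout/talking-heads steps are omitted.
import jax
import jax.numpy as jnp

B, Q, K, H, D = 2, 5, 7, 4, 16
q = jax.random.normal(jax.random.PRNGKey(0), (B, Q, H, D)) / jnp.sqrt(D)
k = jax.random.normal(jax.random.PRNGKey(1), (B, K, H, D))
v = jax.random.normal(jax.random.PRNGKey(2), (B, K, H, D))

weights = jax.nn.softmax(jnp.einsum('...qhd,...khd->...hqk', q, k))  # (B, H, Q, K)
out = jnp.einsum('...hqk,...khd->...qhd', weights, v)                # (B, Q, H, D)
assert out.shape == (B, Q, H, D)
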
def __call__(self, targets, encoded, decoder_mask=None, encoder_decoder_mask=None):
    """Applies Transformer to decode the targets.

    Args:
      targets: target outputs.
      encoded: encoded input data from encoder [batch, ..., length, mlp_dim].
      decoder_mask: decoder self-attention mask.
      encoder_decoder_mask: encoder-decoder attention mask.

    Returns:
      output of a transformer decoder.
    """
    cfg = self.config
    assert encoded.ndim == targets.ndim + 1

    output_embed = nn.Embed(
        num_embeddings=cfg.output_vocab_size,
        features=cfg.emb_dim,
        embedding_init=nn.initializers.normal(stddev=1.0),
        name='embed_output')

    heads = dict()
    y = targets.astype('int32')
    if cfg.shift:
        y = shift_right(y, cfg.bos_token)
    y = output_embed(y)
    y = AddPositionEmbs(config=cfg, cache=cfg.decode, name='posembed_output')(y)
    y = nn.Dropout(rate=cfg.dropout_rate)(y, deterministic=cfg.deterministic)

    y = y.astype(cfg.dtype)

    # Target-Input Decoder
    for lyr in range(cfg.num_layers):
        y = EncoderDecoderBlock(config=cfg, name=f'encoderdecoderblock_{lyr}')(
            y, encoded, decoder_mask, encoder_decoder_mask)
    y = nn.LayerNorm(dtype=cfg.dtype, name='encoderdecoder_norm')(y)
    heads['output_emb'] = y * (jnp.where(targets > 0, 1, 0).astype(jnp.float32)[..., None])

    logits = nn.Dense(cfg.output_vocab_size,
                      kernel_init=cfg.kernel_init,
                      bias_init=cfg.bias_init,
                      name='logitdense')(y)
    heads['logits'] = logits
    if cfg.output_head:
        return heads[cfg.output_head]
    else:
        # Return both output embeddings and logits.
        return heads

def __call__(self, inputs, train):
    """Applies Transformer MlpBlock module."""
    actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
    x = nn.Dense(self.mlp_dim,
                 dtype=self.dtype,
                 kernel_init=self.kernel_init,
                 bias_init=self.bias_init)(inputs)
    x = nn.relu(x)
    x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
    output = nn.Dense(actual_out_dim,
                      dtype=self.dtype,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x)
    output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output)
    return output

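# Minimal, self-contained sketch (hypothetical TinyMlpBlock, not from the original
# repo) of the train/deterministic pattern used above: flax.linen dropout is
# disabled with deterministic=not train and needs an explicit 'dropout' RNG
# stream whenever it is active.
import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyMlpBlock(nn.Module):
    mlp_dim: int = 64
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, train: bool):
        x = nn.Dense(self.mlp_dim)(inputs)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
        return nn.Dense(inputs.shape[-1])(x)

x = jnp.ones((2, 16, 32))
model = TinyMlpBlock()
variables = model.init(jax.random.PRNGKey(0), x, train=False)
# At train time the 'dropout' RNG must be supplied explicitly to apply().
y = model.apply(variables, x, train=True, rngs={'dropout': jax.random.PRNGKey(1)})
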
def setup(self):
    self.encoder = encoder_registry.get_registered_encoder(
        self.encoder_name)(**self.encoder_config)
    if self.apply_mlp:
        # Pass dtype by keyword: the second positional argument of nn.Dense is use_bias.
        self.mlp = nn.Dense(self.encoder_config.hidden_size, dtype=self.dtype)
        self.dropout = nn.Dropout(self.encoder_config.dropout_rate)
    self.linear_classifier = nn.Dense(self.vocab_size, dtype=self.dtype)

def setup(self):
    self.dense = nn.Dense(
        self.config.hidden_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        dtype=self.dtype,
    )
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
    self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

def setup(self):
    self.dropout_layer = nn.Dropout(rate=self.dropout_rate)
    self.keys_only_mlp_attention = KeysOnlyMlpAttention(hidden_size=self.hidden_size)
    self.mlp = MLP(hidden_size=self.hidden_size,
                   output_size=self.output_size,
                   output_bias=False,
                   dropout_rate=self.dropout_rate)

def setup(self):
    self.wav2vec2 = FlaxWav2Vec2Module(self.config, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.final_dropout)
    self.lm_head = nn.Dense(
        self.config.vocab_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range, self.dtype),
        dtype=self.dtype,
    )

def setup(self):
    self.cls_token = self.param("cls_token", nn.initializers.zeros,
                                (1, 1, self.config.hidden_size))
    self.patch_embeddings = FlaxViTPatchEmbeddings(self.config, dtype=self.dtype)
    num_patches = self.patch_embeddings.num_patches
    self.position_embeddings = self.param(
        "position_embeddings",
        nn.initializers.zeros,
        (1, num_patches + 1, self.config.hidden_size),
    )
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

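# Hedged sketch (assumed, not copied from the source) of how the parameters
# declared above are usually combined in the matching __call__: the learned
# cls_token is broadcast across the batch, prepended to the patch embeddings,
# and the position embeddings are added before dropout.
import jax.numpy as jnp

def combine_embeddings(cls_token, position_embeddings, patch_embeddings):
    # cls_token: (1, 1, H); position_embeddings: (1, N + 1, H); patch_embeddings: (B, N, H)
    batch_size = patch_embeddings.shape[0]
    cls_tokens = jnp.broadcast_to(cls_token, (batch_size, 1, cls_token.shape[-1]))
    embeddings = jnp.concatenate([cls_tokens, patch_embeddings], axis=1)
    return embeddings + position_embeddings
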
def __call__(self, inputs):
    x = inputs
    for size in self.feature_sizes:
        x = nn.Dense(features=size)(x)
        x = self.activation(x)
        x = nn.Dropout(rate=self.dropout_rate,
                       deterministic=self.deterministic)(x)
    return x

def __call__(self, inputs, deterministic=True):
    """Applies Transformer MlpBlock module."""
    cfg = self.config
    actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
    x = nn.Dense(cfg.mlp_dim,
                 dtype=cfg.dtype,
                 kernel_init=cfg.kernel_init,
                 bias_init=cfg.bias_init)(inputs)
    x = nn.elu(x)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
    output = nn.Dense(actual_out_dim,
                      dtype=cfg.dtype,
                      kernel_init=cfg.kernel_init,
                      bias_init=cfg.bias_init)(x)
    output = nn.Dropout(rate=cfg.dropout_rate)(output, deterministic=deterministic)
    return output

def setup(self):
    self.roberta = FlaxRobertaModule(config=self.config,
                                     dtype=self.dtype,
                                     add_pooling_layer=False)
    classifier_dropout = (
        self.config.classifier_dropout
        if self.config.classifier_dropout is not None
        else self.config.hidden_dropout_prob
    )
    self.dropout = nn.Dropout(rate=classifier_dropout)
    self.classifier = nn.Dense(self.config.num_labels, dtype=self.dtype)

def __call__(self, input_qkv):
    cfg = self.config
    assert cfg.max_len % cfg.max_seg_len == 0
    bsize = input_qkv.shape[0]
    features = self.out_features or input_qkv.shape[-1]
    num_seg = cfg.max_len // cfg.max_seg_len

    x_sqr = input_qkv.reshape(
        [bsize, num_seg, cfg.max_seg_len, input_qkv.shape[-1]])
    q_row_local, key_row_local, value_row_local, head_dim = get_qkv(cfg, x_sqr)

    local_logits = jnp.einsum('...qhd,...khd->...qhk', q_row_local, key_row_local)
    row_probs = jax.nn.softmax(local_logits)
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
        dropout_rng = self.make_rng('dropout')
        row_probs = dropatt(row_probs, dropout_rng, 1 - cfg.attention_dropout_rate)
    row_attn_out = jnp.einsum('...qhk,...khd->...qhd', row_probs, value_row_local)

    key_row = DenseGeneral(features=input_qkv.shape[-1],
                           axis=(-2, -1),
                           kernel_init=cfg.kernel_init,
                           bias_init=cfg.bias_init,
                           use_bias=False,
                           dtype=cfg.dtype)(row_attn_out)
    key_row = nn.Dropout(rate=cfg.dropout_rate)(key_row, deterministic=cfg.deterministic)
    key_row = key_row + x_sqr
    key_row = nn.LayerNorm(dtype=cfg.dtype)(key_row)
    key_row = DenseGeneral(axis=-1,
                           features=(cfg.num_heads, head_dim),
                           kernel_init=cfg.kernel_init,
                           bias_init=cfg.bias_init,
                           use_bias=False,
                           dtype=cfg.dtype)(key_row)

    idx_cols = jnp.arange(cfg.max_seg_len)
    local_mask = nn.make_attention_mask(idx_cols, idx_cols, jnp.less, extra_batch_dims=1)
    local_mask = jnp.expand_dims(local_mask, axis=-2) * -1e10
    local_logits = local_logits + local_mask

    global_logits = jnp.einsum('bqlhd,bklhd->bqlhk', q_row_local, key_row)
    idx_rows = jnp.arange(num_seg)
    global_mask = nn.make_attention_mask(idx_rows, idx_rows, jnp.less_equal)
    global_mask = global_mask[:, :, jnp.newaxis, jnp.newaxis, :] * -1e10
    global_logits = global_logits + global_mask

    joint_logits = jnp.concatenate((local_logits, global_logits), axis=-1)
    attn_probs = jax.nn.softmax(joint_logits, axis=-1)
    local_att, global_att = jnp.split(attn_probs, [cfg.max_seg_len], axis=-1)
    if not cfg.deterministic and cfg.attention_dropout_rate > 0.:
        dropout_rng = self.make_rng('dropout')
        local_att = dropatt(local_att, dropout_rng, 1 - cfg.attention_dropout_rate)

    local_merged = jnp.einsum('bsqhk,bskhd->bsqhd', local_att, value_row_local)
    global_merged = jnp.einsum('bqlhv,bvlhd->bqlhd', global_att, row_attn_out)
    joint_merged = jnp.reshape(local_merged + global_merged,
                               [bsize, cfg.max_len, cfg.num_heads, head_dim])
    x = DenseGeneral(features=features,
                     axis=(-2, -1),
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init,
                     use_bias=False,
                     dtype=cfg.dtype)(joint_merged)
    return x

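# Small sketch (illustrative length, not original code) of the
# flax.linen.make_attention_mask calls above. The pairwise_fn marks the entries
# that receive the -1e10 penalty: jnp.less suppresses keys strictly after the
# query within a segment, jnp.less_equal suppresses the current and later
# segments for the global term.
import jax.numpy as jnp
import flax.linen as nn

idx = jnp.arange(4)
within = nn.make_attention_mask(idx, idx, jnp.less)        # shape (1, 4, 4); entry [0, q, k] = (q < k)
across = nn.make_attention_mask(idx, idx, jnp.less_equal)  # shape (1, 4, 4); entry [0, q, k] = (q <= k)
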
def setup(self):
    self.mlp = nn.Dense(features=self.hidden_dim,
                        kernel_init=self.kernel_init,
                        bias_init=self.bias_init)
    self.dense = nn.Dense(features=self.input_dim,
                          kernel_init=self.kernel_init,
                          bias_init=self.bias_init)
    self.dropout = nn.Dropout(self.dropout_rate)
    self.layer_norm = nn.LayerNorm(epsilon=self.layer_norm_epsilon)

def setup(self):
    embed_dim = self.config.hidden_size
    kernel_init = jax.nn.initializers.normal(self.config.initializer_range)

    self.fc_in = nn.Dense(self.intermediate_size, dtype=self.dtype, kernel_init=kernel_init)
    self.fc_out = nn.Dense(embed_dim, dtype=self.dtype, kernel_init=kernel_init)

    self.act = ACT2FN[self.config.activation_function]
    self.dropout = nn.Dropout(rate=self.config.resid_pdrop)