def setup(self):
    self.beit = FlaxBeitModule(self.config, add_pooling_layer=False, dtype=self.dtype)

    # Classifier head
    self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.lm_head = nn.Dense(
        self.config.vocab_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        dtype=self.dtype,
    )

def setup(self):
    self.attention = FlaxBeitAttention(self.config, self.window_size, dtype=self.dtype)
    self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
    self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
    self.layernorm_before = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
    self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)

    self.init_values = self.config.layer_scale_init_value
    if self.init_values > 0:
        self.lambda_1 = self.param("lambda_1", ones_with_scale, (self.config.hidden_size), self.init_values)
        self.lambda_2 = self.param("lambda_2", ones_with_scale, (self.config.hidden_size), self.init_values)
    else:
        self.lambda_1 = None
        self.lambda_2 = None

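# Note: the LayerScale parameters above rely on a `ones_with_scale` initializer that is not
# defined in this section. A minimal sketch consistent with how `self.param` passes the
# (shape, scale) arguments, given here as an assumption rather than the verbatim source:
import jax.numpy as jnp

def ones_with_scale(key, shape, scale, dtype=jnp.float32):
    # Every entry starts at `scale`, giving lambda_1/lambda_2 their initial LayerScale value.
    return jnp.ones(shape, dtype) * scale
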
def setup(self):
    self.dense = nn.Dense(
        self.config.hidden_size,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )
    self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.decoder = nn.Dense(
        self.config.vocab_size,
        dtype=self.dtype,
        use_bias=False,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
    )
    self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))

def __call__(self, inputs, encoder_mask=None): """Applies Encoder1DBlock module. Args: inputs: input data. encoder_mask: encoder self-attention mask. Returns: output after transformer encoder block. """ cfg = self.config # Attention block. assert inputs.ndim == 3 x = nn.LayerNorm(dtype=cfg.dtype)(inputs) x = nn.SelfAttention( num_heads=cfg.num_heads, dtype=cfg.dtype, qkv_features=cfg.qkv_dim, kernel_init=cfg.kernel_init, bias_init=cfg.bias_init, use_bias=False, broadcast_dropout=False, dropout_rate=cfg.attention_dropout_rate, deterministic=cfg.deterministic)(x, encoder_mask) x = nn.Dropout(rate=cfg.dropout_rate)( x, deterministic=cfg.deterministic) x = x + inputs # MLP block. y = nn.LayerNorm(dtype=cfg.dtype)(x) y = MlpBlock(config=cfg)(y) return x + y
def __call__(self, inputs, is_training: bool):
    x = AddAbsPosEmbed()(inputs)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not is_training)

    for _ in range(self.num_layers):
        x = EncoderBlock(
            num_heads=self.num_heads,
            expand_ratio=self.expand_ratio,
            attn_dropout_rate=self.attn_dropout_rate,
            dropout_rate=self.dropout_rate,
            activation_fn=self.activation_fn,
            dtype=self.dtype)(x, is_training=is_training)

    output = nn.LayerNorm(dtype=self.dtype)(x)
    return output

def __call__(self, x):
    # TODO(lbeyer): condition on GAP(x)
    n, _, d = x.shape
    probe = self.param('probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype)
    probe = jnp.tile(probe, [n, 1, 1])

    x = nn.MultiHeadDotProductAttention(
        num_heads=self.num_heads,
        kernel_init=nn.initializers.xavier_uniform())(probe, x)

    # TODO(lbeyer): dropout on head?
    y = nn.LayerNorm()(x)
    x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
    return x[:, 0]

def __call__(self, inputs: jnp.ndarray, *, deterministic: Optional[bool] = None):
    """Applies Encoder1DBlock module."""
    assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
    x = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_0")(inputs)
    x = nn.MultiHeadDotProductAttention(
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        name="MultiHeadDotProductAttention_1",
        num_heads=self.num_heads,
        dropout_rate=self.attention_dropout_rate)(x, x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_2")(x)
    y = self.mlp_class(name="MlpBlock_3")(y, deterministic=deterministic)
    return x + y

def setup(self):
    self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
    self.ffn = nn.Dense(
        self.config.intermediate_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        dtype=self.dtype,
    )
    self.activation = ACT2FN[self.config.hidden_act]
    self.ffn_output = nn.Dense(
        self.config.hidden_size,
        kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
        dtype=self.dtype,
    )
    self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, inputs, *, train):
    del train
    x = nn.Conv(self.hidden_dim, self.patches.size,
                strides=self.patches.size, name='stem')(inputs)
    x = einops.rearrange(x, 'n h w c -> n (h w) c')
    for i in range(self.num_blocks):
        # Alternate the two MixerBlock variants between odd and even blocks.
        if i % 2 != 0:
            x = models_mixer_ct_cat_tc.MixerBlock_ct_cat_tc(
                self.tokens_mlp_dim, self.channels_mlp_dim)(x)
        if i % 2 == 0:
            x = models_mixer_tc_cat_ct.MixerBlock_tc_cat_ct(
                self.tokens_mlp_dim, self.channels_mlp_dim)(x)
    x = nn.LayerNorm(name='pre_head_layer_norm')(x)
    x = jnp.mean(x, axis=1)
    return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                    name='head')(x)

def setup(self):
    self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
    self.out_conv_dim = self.config.conv_dim[self.layer_id]

    self.conv = nn.Conv(
        features=self.config.conv_dim[self.layer_id],
        kernel_size=self.config.conv_kernel[self.layer_id],
        strides=(self.config.conv_stride[self.layer_id],),
        use_bias=self.config.conv_bias,
        kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
        padding="VALID",
        dtype=self.dtype,
    )
    self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.activation = ACT2FN[self.config.feat_extract_activation]

def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}
    # Convert images to patches.
    x = self.patches(images, self.hidden_size, self.patch_size, self.patch_grid)
    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
        cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)
    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="BatchEnsembleTransformer", **transformer)(x)
    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
        # Take the first token's output as representation as in BERT.
        x = x[:, 0]
    elif self.classifier == "gap":
        # Average all tokens.
        x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
        probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
        probe = jnp.tile(probe, [n, 1, 1])
        attention = nn.MultiHeadDotProductAttention(
            deterministic=not train,
            num_heads=transformer.get("attention", {}).get("num_heads", 1),
            kernel_init=nn.initializers.xavier_uniform())
        x = attention(inputs_q=probe, inputs_kv=x)
        y = nn.LayerNorm()(x)
        y = patch_transformer_lib.MlpBlock(
            mlp_dim=transformer["mlp_dim"], dropout_rate=0, deterministic=not train)(y)
        x = (x + y)[:, 0]
    else:
        raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
        x = identity.IdentityLayer(name="pre_logits")(x)
    else:
        x = nn.Dense(self.representation_size, name="pre_logits")(x)
        x = nn.tanh(x)

    x = nn.Dense(self.num_classes, kernel_init=self.head_kernel_init, name="head")(x)
    return x, extra_info

def __call__(self, inputs, inputs_positions=None, inputs_segmentation=None):
    """Applies Transformer model on the inputs.

    Args:
      inputs: input data.
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer encoder.
    """
    cfg = self.config

    assert inputs.ndim == 2  # (batch, len)

    # Padding Masks
    src_padding_mask = (inputs > 0)[..., None]

    # Input Embedding
    if self.shared_embedding is None:
        input_embed = nn.Embed(
            num_embeddings=cfg.vocab_size,
            features=cfg.emb_dim,
            embedding_init=nn.initializers.normal(stddev=1.0))
    else:
        input_embed = self.shared_embedding
    x = inputs.astype('int32')
    x = input_embed(x)
    x = AddPositionEmbs(config=cfg, name='posembed_input')(
        x, inputs_positions=inputs_positions)
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)

    x = x.astype(cfg.dtype)

    # Input Encoder
    for lyr in range(cfg.num_layers):
        x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')(
            x,
            padding_mask=src_padding_mask,
            inputs_segmentation=inputs_segmentation)

    encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

    return encoded

def setup(self):
    self.embed_dim = self.config.hidden_size

    self.wte = nn.Embed(
        self.config.vocab_size,
        self.embed_dim,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.wpe = nn.Embed(
        self.config.max_position_embeddings,
        self.embed_dim,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
    self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
    self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)

def __call__(self, inputs, is_training: bool):
    x = PatchEmbedBlock(
        patch_shape=self.patch_shape,
        embed_dim=self.embed_dim,
        use_bias=True,
        dtype=self.dtype)(inputs)

    for _ in range(self.num_layers):
        x = MixerBlock(
            tokens_expand_ratio=self.tokens_expand_ratio,
            channels_expand_ratio=self.channels_expand_ratio,
            activation_fn=self.activation_fn,
            dtype=self.dtype)(x, is_training=is_training)

    x = nn.LayerNorm(dtype=self.dtype)(x)
    x = jnp.mean(x, axis=1)
    output = nn.Dense(features=self.num_classes, dtype=self.dtype)(x)
    return output

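# A hedged usage sketch for the Mixer classifier above. The class name `MlpMixer` and the
# concrete field values are assumptions; only the fields referenced in the method
# (patch_shape, embed_dim, num_layers, expand ratios, activation_fn, dtype, num_classes)
# come from the code itself.
import jax
import jax.numpy as jnp
import flax.linen as nn

model = MlpMixer(
    patch_shape=(16, 16), embed_dim=512, num_layers=8,
    tokens_expand_ratio=0.5, channels_expand_ratio=4.0,
    activation_fn=nn.gelu, dtype=jnp.float32, num_classes=1000)
images = jnp.zeros((2, 224, 224, 3))
variables = model.init(jax.random.PRNGKey(0), images, is_training=False)
logits = model.apply(variables, images, is_training=False)  # shape (2, 1000)
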
def __call__(self, inputs):
    config = self.config
    assert inputs.ndim == 3  # (batch, len, embed)

    y = AddPositionEmbs(config=config)(inputs)
    y = nn.Dropout(rate=config.dropout_rate)(y, deterministic=config.deterministic)

    assert issubclass(type(self.transformer_layer), partial)
    for l in range(config.num_layers):
        y = self.transformer_layer(name='transformer_layer_%d' % l)(y)

    y = nn.LayerNorm(dtype=config.dtype)(y)
    out = nn.Dense(
        self.pred_dim,
        dtype=config.dtype,
        kernel_init=config.kernel_init,
        bias_init=config.bias_init)(y)
    return out

def setup(self):
    self.word_embeddings = nn.Embed(
        self.config.vocab_size,
        self.config.dim,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    if not self.config.sinusoidal_pos_embds:
        self.position_embeddings = nn.Embed(
            self.config.max_position_embeddings,
            self.config.dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
    else:
        self.pos_encoding = positional_encoding(
            self.config.max_position_embeddings, self.config.dim)
    self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.dropout)

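# The sinusoidal branch above calls a `positional_encoding` helper that is not shown in this
# section. One common NumPy-based construction with the same (max_position, d_model) signature
# is sketched below; the exact upstream helper may differ in detail.
import numpy as np
import jax.numpy as jnp

def positional_encoding(max_position, d_model):
    # angle(pos, i) = pos / 10000^(2 * (i // 2) / d_model); sin on even dims, cos on odd dims.
    pos = np.arange(max_position)[:, np.newaxis].astype(np.float32)
    i = np.arange(d_model)[np.newaxis, :]
    angle_rads = pos / np.power(10000.0, (2 * (i // 2)) / np.float32(d_model))
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    return jnp.array(angle_rads[np.newaxis, ...])  # shape (1, max_position, d_model)
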
def setup(self):
    self.word_embeddings = nn.Embed(
        self.config.vocab_size,
        self.config.hidden_size,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.token_type_embeddings = nn.Embed(
        self.config.type_vocab_size,
        self.config.hidden_size,
        embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
    self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)

def __call__(self, inputs):
    cfg = self.config
    assert inputs.ndim == 3

    dense = partial(
        nn.DenseGeneral,
        axis=-1,
        features=(cfg.num_heads, cfg.dim_head),
        use_bias=False,
        kernel_init=cfg.kernel_init,
        precision=cfg.precision)

    query, key, value = (dense(dtype=cfg.dtype)(inputs),
                         dense(dtype=cfg.dtype)(inputs),
                         dense(dtype=cfg.dtype)(inputs))
    query = query / jnp.sqrt(cfg.dim_head).astype(cfg.dtype)

    attn_weights = jnp.einsum(
        'b q h d, b k h d -> b h q k', query, key, precision=cfg.precision)
    attn_weights = nn.softmax(attn_weights).astype(cfg.dtype)

    if cfg.shared_theta:
        attn_weights = self.theta_transform(attn_weights)
    else:
        attn_weights = ThetaTransform(config=cfg)(attn_weights)
    attn_weights = nn.LayerNorm()(attn_weights)

    out = jnp.einsum(
        'b h q k, b q h d -> b k h d', attn_weights, value, precision=cfg.precision)

    if (cfg.num_heads * cfg.dim_head) != cfg.emb_dim:
        out = nn.DenseGeneral(
            features=cfg.emb_dim,
            axis=(-2, -1),
            dtype=cfg.dtype,
            precision=cfg.precision,
            kernel_init=cfg.kernel_init,
            bias_init=cfg.bias_init)(out)
    else:
        out = rearrange(out, 'b k h d -> b k (h d)')
    return out

def __call__(self, intermediate_output, attention_output, deterministic: bool = True):
    hidden_states = nn.Dense(
        attention_output.shape[-1],
        kernel_init=jax.nn.initializers.normal(self.kernel_init_scale, self.dtype),
        name="dense",
        dtype=self.dtype,
    )(intermediate_output)
    hidden_states = nn.Dropout(rate=self.dropout_rate)(
        hidden_states, deterministic=deterministic)
    hidden_states = nn.LayerNorm(name="layer_norm", dtype=self.dtype)(
        hidden_states + attention_output)
    return hidden_states

def setup(self):
    self.embedder = embedding.DictEmbed({
        'token_ids': embedding.Embed(
            num_embeddings=self.vocab_size,
            embedding_dim=self.hidden_size,
            dtype=self.dtype,
            embedding_init=self.kernel_init,
        ),
        'position_ids': embedding.Embed(
            num_embeddings=self.max_positions,
            embedding_dim=self.hidden_size,
            dtype=self.dtype,
            embedding_init=self.kernel_init,
        ),
        'segment_ids': embedding.Embed(
            num_embeddings=self.num_segments,
            embedding_dim=self.hidden_size,
            dtype=self.dtype,
            embedding_init=self.kernel_init,
        ),
    })

    self.embeddings_layer_norm = nn.LayerNorm(epsilon=self.layer_norm_epsilon)
    self.embeddings_dropout = nn.Dropout(rate=self.dropout_rate)

    self.encoder = transformer.TransformerBlock(
        num_layers=self.num_layers,
        model_dim=self.hidden_size,
        intermediate_dim=self.intermediate_dim,
        num_heads=self.num_attention_heads,
        dropout_rate=self.dropout_rate,
        dtype=self.dtype,
        kernel_init=self.kernel_init,
        bias_init=self.bias_init,
        layer_norm_epsilon=self.layer_norm_epsilon,
    )

    self.mention_projector = nn.Dense(
        features=self.mention_encoding_dim,
        dtype=self.dtype,
    )

def __call__(self, x, mask):
    n, d, h, dh, dim = *x.shape, self.heads, self.dim_head, self.dim

    if d != dim:
        x = nn.Dense(features=dim)(x)

    cls_token = self.param("cls", self.cls_init, (1, x.shape[-1]))
    to_norm_out = nn.LayerNorm()

    sincos = fixed_pos_embedding(n, self.dim_head)

    x = np.concatenate((cls_token, x), axis=0)

    for attn, ff in self.layers:
        x = attn(x, pos_emb=sincos, mask=mask) + x
        x = ff(x) + x

    x = to_norm_out(x)
    return x

def __call__(self,
             inputs: jnp.ndarray,
             inputs_positions: Optional[jnp.ndarray] = None,
             train: Optional[bool] = None):
    """Applies Transformer model on the inputs."""
    train = nn.module.merge_param("train", self.train, train)
    dtype = self.dtype or inputs.dtype
    assert inputs.ndim == 3  # (batch, len, emb)

    # List indicating which MLPs to substitute with BatchEnsemble MLPs.
    be_layers = self.be_layers
    if be_layers is None:
        be_layers = list(range(1, self.num_layers, 2))

    x = patch_transformer_lib.AddPositionEmbs(name="posembed_input")(
        inputs, inputs_positions)
    x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)

    be_params = dict(ens_size=self.ens_size, random_sign_init=self.random_sign_init)
    mlp_params = dict(dtype=dtype, deterministic=not train, name="mlp")
    mlp_params_dense = dict(dropout_rate=self.dropout_rate, mlp_dim=self.mlp_dim)
    mlp_dense = functools.partial(
        patch_transformer_lib.MlpBlock, **mlp_params, **mlp_params_dense)
    be_block = functools.partial(
        BatchEnsembleMlpBlock, **mlp_params, **mlp_params_dense, **be_params)

    extra_info = dict()
    for lyr in range(self.num_layers):
        encoder_block = functools.partial(
            patch_transformer_lib.Encoder1DBlock,
            num_heads=self.num_heads,
            dtype=dtype,
            dropout_rate=self.dropout_rate,
            deterministic=not train,
            attention_dropout_rate=self.attention_dropout_rate,
            name=f"encoderblock_{lyr}")
        if lyr in be_layers:
            x = encoder_block(mlp_class=be_block)(x)
        else:
            x = encoder_block(mlp_class=mlp_dense)(x)
    encoded = nn.LayerNorm(name="encoder_norm")(x)

    return encoded, extra_info

def __call__(self, z, a, key):
    kernel_initializer = jax.nn.initializers.glorot_uniform()
    x = nn.Dense(
        features=self.layer_width,
        kernel_init=kernel_initializer)(jnp.concatenate([z, a]))
    x = nn.LayerNorm()(x)
    x = nn.relu(x)
    mu = nn.Dense(features=self.embedding_dim, kernel_init=kernel_initializer)(x)
    if self.probabilistic:
        sigma = nn.Dense(features=self.embedding_dim, kernel_init=kernel_initializer)(x)
        sigma = nn.sigmoid(sigma)
        sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma
        eps = jax.random.normal(key, shape=sigma.shape)
        sample = mu + sigma * eps
    else:
        sigma = jnp.zeros(self.embedding_dim)
        sample = mu
    return DynamicsModelType(mu, sigma, sample)

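# A hedged usage sketch for the dynamics model above. The class name `DynamicsModel` and the
# field values are assumptions; only the (z, a, key) call signature and the returned
# (mu, sigma, sample) fields come from the code itself.
import jax
import jax.numpy as jnp

model = DynamicsModel(layer_width=256, embedding_dim=32, probabilistic=True,
                      min_sigma=1e-4, max_sigma=1.0)
z = jnp.zeros((32,))   # current latent embedding
a = jnp.zeros((4,))    # action vector, concatenated with z inside the module
init_rng, sample_key = jax.random.split(jax.random.PRNGKey(0))
params = model.init(init_rng, z, a, sample_key)
mu, sigma, sample = model.apply(params, z, a, sample_key)
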
def setup(self):
    self.distilbert = FlaxDistilBertModule(self.config, dtype=self.dtype)
    self.vocab_transform = nn.Dense(
        self.config.dim,
        dtype=self.dtype,
        kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
    )
    self.vocab_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
    if self.config.tie_word_embeddings:
        self.vocab_projector = FlaxDistilBertLMDecoder(
            self.config,
            dtype=self.dtype,
        )
    else:
        self.vocab_projector = nn.Dense(
            self.config.vocab_size,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )

def __call__(self, inputs, is_training: bool):
    inputs = zero_pad_and_reshape(inputs)

    x = CvTSelfAttentionBlock(
        num_heads=self.num_heads,
        kernel_size=self.kernel_size,
        use_bias=self.use_bias,
        bn_momentum=self.bn_momentum,
        bn_epsilon=self.bn_epsilon,
        dtype=self.dtype)(inputs, is_training=is_training)
    x = x + rearrange(inputs, 'b h w d -> b (h w) d')

    y = nn.LayerNorm(dtype=self.dtype)(x)
    y = FFBlock(
        expand_ratio=self.expand_ratio,
        activation_fn=self.activation_fn,
        dtype=self.dtype)(y, is_training=is_training)
    output = x + y
    return output

def __call__(self, input_ids, type_ids, deterministic=False):
    """Applies EmbeddingLayer module.

    Args:
      input_ids: Batch of tokenized inputs of shape
        <int>[BATCH_SIZE, MAX_SEQ_LENGTH].
      type_ids: Ids partitioning input into different types.
      deterministic: Whether or not to apply dropout to output embeddings.

    Returns:
      Embedded tokens of shape <float>[BATCH_SIZE, MAX_SEQ_LENGTH, EMB_DIM].
    """
    word_embeddings = nn.Embed(
        num_embeddings=self.config.vocab_size,
        features=self.config.d_emb,
        embedding_init=default_kernel_init,
        name="word")(input_ids)
    position_embeddings = PositionalEncoding(
        max_seq_length=self.config.max_seq_length,
        posemb_init=default_kernel_init,
        name="position")(word_embeddings)
    type_embeddings = nn.Embed(
        num_embeddings=self.config.type_vocab_size,
        features=self.config.d_emb,
        embedding_init=default_kernel_init,
        name="type")(type_ids)

    embeddings = word_embeddings + position_embeddings + type_embeddings
    embeddings = nn.LayerNorm(epsilon=LAYER_NORM_EPSILON, name="layer_norm")(embeddings)
    embeddings = nn.Dense(self.config.d_model, name="hidden_mapping_in")(embeddings)
    return nn.Dropout(rate=self.config.dropout_rate)(embeddings, deterministic=deterministic)

def __call__(self, inputs, encoder_mask=None): """Applies Transformer model on the inputs. Args: inputs: input data inputs_positions: input subsequence positions for packed examples. encoder_mask: decoder self-attention mask. Returns: output of a transformer encoder. """ cfg = self.config x = inputs # Input Encoder for lyr in range(cfg.num_layers): x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')(x, encoder_mask) encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x) return encoded
def __call__(self, node_feature_embeddings, node_position_embeddings, adjacency_mat, qstar):
    """Applies the TransformerEncoder module.

    Args:
      node_feature_embeddings: Embeddings representing nodes.
      node_position_embeddings: Embeddings representing node positions.
      adjacency_mat: Adjacency matrix over the nodes. Not used for now.
      qstar: float tensor of shape (num_of_nodes,). The optimal q weighting over
        the nodes of the graph, from the subgraph selection module.

    Returns:
      encoded: Encoded nodes, with an extra graph node at the end.
    """
    cfg = self.config

    x = node_feature_embeddings + node_position_embeddings
    # Add average weight to graph node for scale.
    qstar = jnp.append(qstar, jnp.mean(qstar))
    # Multiply embeddings by node weights => learn the agent model.
    x = x * qstar[Ellipsis, None]
    x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=cfg.deterministic)
    x = x.astype(cfg.dtype)
    # TODO(gnegiar): Plot x here to check.

    # Keep nodes with positive weights.
    mask1d = qstar != 0
    encoder_mask = nn.attention.make_attention_mask(mask1d, mask1d)

    # Input Encoder
    for lyr in range(cfg.num_layers):
        x = Encoder1DBlock(config=cfg, name=f"encoderblock_{lyr}")(x, encoder_mask)
        x = x * mask1d[Ellipsis, None]
        # TODO(gnegiar): Also plot x here.
        # Possibly plot gradient norms per encoder layer.
        # Plot attention weights?

    encoded = nn.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x)
    return encoded

def __call__(self, inputs, is_training: bool):
    x = PatchEmbedBlock(
        patch_shape=self.patch_shape,
        embed_dim=self.embed_dim,
        dtype=self.dtype)(inputs)
    x = Encoder(
        num_layers=self.num_layers,
        num_heads=self.num_heads,
        expand_ratio=self.expand_ratio,
        attn_dropout_rate=self.attn_dropout_rate,
        dropout_rate=self.dropout_rate,
        stoch_depth_rate=self.stoch_depth_rate,
        layerscale_eps=self.layerscale_eps,
        activation_fn=self.activation_fn)(x, is_training=is_training)

    b = x.shape[0]
    cls_shape = (1, 1, self.embed_dim)
    cls_token = self.param('cls', nn.initializers.zeros, cls_shape)
    cls_token = jnp.tile(cls_token, [b, 1, 1])

    for _ in range(self.num_layers_token_only):
        cls_token = CAEncoderBlock(
            num_heads=self.num_heads,
            expand_ratio=self.expand_ratio,
            attn_dropout_rate=self.attn_dropout_rate,
            dropout_rate=self.dropout_rate,
            stoch_depth_rate=self.stoch_depth_rate,
            layerscale_eps=self.layerscale_eps,
            activation_fn=self.activation_fn,
            dtype=self.dtype)(x, cls_token, is_training=is_training)

    x = jnp.concatenate([cls_token, x], axis=1)
    x = nn.LayerNorm(dtype=self.dtype)(x)
    cls_token = x[:, 0]
    output = nn.Dense(
        features=self.num_classes,
        use_bias=True,
        dtype=self.dtype,
        kernel_init=nn.initializers.zeros)(cls_token)
    return output

def setup(self):
    self.dropout_layer = nn.Dropout(rate=self.config.dropout)

    embed_dim = self.config.d_model
    self.padding_idx = self.config.pad_token_id
    self.max_target_positions = self.config.max_position_embeddings
    self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

    self.embed_tokens = nn.Embed(
        self.config.vocab_size,
        embed_dim,
        embedding_init=jax.nn.initializers.normal(self.config.init_std),
    )

    # XGLM is set up so that if padding_idx is specified then offset the embedding ids
    # by 2 and adjust num_embeddings appropriately. Other models don't have this hack.
    self.offset = 2
    self.embed_positions = create_sinusoidal_positions(
        self.config.max_position_embeddings + self.offset, embed_dim)
    self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype)
    self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)

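# `create_sinusoidal_positions` is referenced above but not defined in this section. Below is a
# sketch of a sinusoidal table builder with the same (n_pos, dim) signature, given as an
# assumption; the exact layout used by the upstream XGLM implementation may differ.
import math
import numpy as np
import jax.numpy as jnp

def create_sinusoidal_positions(n_pos, dim):
    # Half of the channels carry sines, the other half cosines, over geometric frequencies.
    half_dim = dim // 2
    freqs = np.exp(np.arange(half_dim) * -(math.log(10000.0) / (half_dim - 1)))
    angles = np.arange(n_pos)[:, np.newaxis] * freqs[np.newaxis, :]
    table = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)
    if dim % 2 == 1:
        # Zero-pad the last column when the embedding dimension is odd.
        table = np.concatenate([table, np.zeros((n_pos, 1))], axis=-1)
    return jnp.array(table)  # shape (n_pos, dim)
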