def setup(self):
        self.beit = FlaxBeitModule(self.config,
                                   add_pooling_layer=False,
                                   dtype=self.dtype)

        # Classifier head
        self.layernorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,
                                      dtype=self.dtype)
        self.lm_head = nn.Dense(
            self.config.vocab_size,
            kernel_init=jax.nn.initializers.normal(
                self.config.initializer_range),
            dtype=self.dtype,
        )
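The matching __call__ is not part of this excerpt; a plausible sketch, assuming the usual masked-image-modeling wiring (run the backbone, LayerNorm the sequence output, and score the patch tokens with lm_head — argument names here are illustrative):

def __call__(self, pixel_values, deterministic: bool = True):
        # Sketch only: backbone -> LayerNorm -> vocabulary scores for the
        # patch tokens (the [CLS] token is dropped before scoring).
        outputs = self.beit(pixel_values, deterministic=deterministic)
        sequence_output = self.layernorm(outputs[0])
        logits = self.lm_head(sequence_output[:, 1:])
        return logits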
    def setup(self):
        self.attention = FlaxBeitAttention(self.config,
                                           self.window_size,
                                           dtype=self.dtype)
        self.intermediate = FlaxBeitIntermediate(self.config, dtype=self.dtype)
        self.output = FlaxBeitOutput(self.config, dtype=self.dtype)
        self.layernorm_before = nn.LayerNorm(
            epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.drop_path = FlaxBeitDropPath(rate=self.drop_path_rate)
        self.layernorm_after = nn.LayerNorm(epsilon=self.config.layer_norm_eps,
                                            dtype=self.dtype)

        self.init_values = self.config.layer_scale_init_value
        if self.init_values > 0:
            self.lambda_1 = self.param("lambda_1", ones_with_scale,
                                       (self.config.hidden_size),
                                       self.init_values)
            self.lambda_2 = self.param("lambda_2", ones_with_scale,
                                       (self.config.hidden_size),
                                       self.init_values)
        else:
            self.lambda_1 = None
            self.lambda_2 = None
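The ones_with_scale initializer used for lambda_1 / lambda_2 is not defined in this snippet; a minimal sketch of what such a LayerScale initializer typically looks like (the exact definition in the source may differ):

import jax.numpy as jnp

# Hypothetical helper: fills the LayerScale vectors with a small constant.
# Flax passes (rng, *init_args) to the initializer given to self.param.
def ones_with_scale(key, shape, scale, dtype=jnp.float32):
    del key  # deterministic initializer; the PRNG key is unused
    return jnp.ones(shape, dtype) * scale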
Example #3
 def setup(self):
     self.dense = nn.Dense(
         self.config.hidden_size,
         dtype=self.dtype,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
     )
     self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
     self.decoder = nn.Dense(
         self.config.vocab_size,
         dtype=self.dtype,
         use_bias=False,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
     )
     self.bias = self.param("bias", self.bias_init, (self.config.vocab_size,))
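For context, the __call__ of a prediction head set up this way usually applies the dense projection, a non-linearity, the LayerNorm, and then the vocabulary decoder plus bias; the sketch below assumes a GELU activation, which is not stated in the snippet:

 def __call__(self, hidden_states):
     # Sketch only: project, activate, normalize, then score the vocabulary.
     hidden_states = self.dense(hidden_states)
     hidden_states = nn.gelu(hidden_states)
     hidden_states = self.layer_norm(hidden_states)
     hidden_states = self.decoder(hidden_states) + self.bias
     return hidden_states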
  def __call__(self,
               inputs,
               encoder_mask=None):
    """Applies Encoder1DBlock module.

    Args:
      inputs: input data.
      encoder_mask: encoder self-attention mask.

    Returns:
      output after transformer encoder block.
    """
    cfg = self.config

    # Attention block.
    assert inputs.ndim == 3
    x = nn.LayerNorm(dtype=cfg.dtype)(inputs)
    x = nn.SelfAttention(
        num_heads=cfg.num_heads,
        dtype=cfg.dtype,
        qkv_features=cfg.qkv_dim,
        kernel_init=cfg.kernel_init,
        bias_init=cfg.bias_init,
        use_bias=False,
        broadcast_dropout=False,
        dropout_rate=cfg.attention_dropout_rate,
        deterministic=cfg.deterministic)(x, encoder_mask)

    x = nn.Dropout(rate=cfg.dropout_rate)(
        x, deterministic=cfg.deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=cfg.dtype)(x)
    y = MlpBlock(config=cfg)(y)

    return x + y
    def __call__(self, inputs, is_training: bool):
        x = AddAbsPosEmbed()(inputs)
        x = nn.Dropout(rate=self.dropout_rate)(x,
                                               deterministic=not is_training)

        for _ in range(self.num_layers):
            x = EncoderBlock(num_heads=self.num_heads,
                             expand_ratio=self.expand_ratio,
                             attn_dropout_rate=self.attn_dropout_rate,
                             dropout_rate=self.dropout_rate,
                             activation_fn=self.activation_fn,
                             dtype=self.dtype)(x, is_training=is_training)

        output = nn.LayerNorm(dtype=self.dtype)(x)
        return output
Example #6
    def __call__(self, x):
        # TODO(lbeyer): condition on GAP(x)
        n, _, d = x.shape
        probe = self.param('probe', nn.initializers.xavier_uniform(),
                           (1, 1, d), x.dtype)
        probe = jnp.tile(probe, [n, 1, 1])

        x = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform())(probe, x)

        # TODO(lbeyer): dropout on head?
        y = nn.LayerNorm()(x)
        x = x + MlpBlock(mlp_dim=self.mlp_dim)(y)
        return x[:, 0]
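Since every snippet on this page revolves around nn.LayerNorm, a minimal standalone sketch of initializing and applying the layer directly may also help (shapes and values here are made up for illustration):

import jax
import jax.numpy as jnp
import flax.linen as nn

layer = nn.LayerNorm(epsilon=1e-6)
x = jnp.ones((2, 16, 768))                     # (batch, seq, hidden), arbitrary
params = layer.init(jax.random.PRNGKey(0), x)  # creates 'scale' and 'bias'
y = layer.apply(params, x)                     # normalized over the last axis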
Example #7
  def __call__(self,
               inputs: jnp.ndarray,
               *,
               deterministic: Optional[bool] = None):
    """Applies Encoder1Dlock module."""
    assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"

    x = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_0")(inputs)
    x = nn.MultiHeadDotProductAttention(
        dtype=self.dtype,
        kernel_init=nn.initializers.xavier_uniform(),
        broadcast_dropout=False,
        deterministic=deterministic,
        name="MultiHeadDotProductAttention_1",
        num_heads=self.num_heads,
        dropout_rate=self.attention_dropout_rate)(x, x)
    x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
    x = x + inputs

    # MLP block.
    y = nn.LayerNorm(dtype=self.dtype, name="LayerNorm_2")(x)
    y = self.mlp_class(name="MlpBlock_3")(y, deterministic=deterministic)

    return x + y
 def setup(self):
     self.attention = FlaxAlbertSelfAttention(self.config, dtype=self.dtype)
     self.ffn = nn.Dense(
         self.config.intermediate_size,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         dtype=self.dtype,
     )
     self.activation = ACT2FN[self.config.hidden_act]
     self.ffn_output = nn.Dense(
         self.config.hidden_size,
         kernel_init=jax.nn.initializers.normal(self.config.initializer_range),
         dtype=self.dtype,
     )
     self.full_layer_layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
Example #9
 def __call__(self, inputs, *, train):
   del train
   x = nn.Conv(self.hidden_dim, self.patches.size,
               strides=self.patches.size, name='stem')(inputs)
   x = einops.rearrange(x, 'n h w c -> n (h w) c')
   
   for i in range(self.num_blocks):
       if i % 2 != 0:
           x = models_mixer_ct_cat_tc.MixerBlock_ct_cat_tc(
               self.tokens_mlp_dim, self.channels_mlp_dim)(x)
       else:
           x = models_mixer_tc_cat_ct.MixerBlock_tc_cat_ct(
               self.tokens_mlp_dim, self.channels_mlp_dim)(x)
   x = nn.LayerNorm(name='pre_head_layer_norm')(x)
   x = jnp.mean(x, axis=1)
   return nn.Dense(self.num_classes, kernel_init=nn.initializers.zeros,
                   name='head')(x)
Example #10
    def setup(self):
        self.in_conv_dim = self.config.conv_dim[self.layer_id] if self.layer_id > 0 else 1
        self.out_conv_dim = self.config.conv_dim[self.layer_id]

        self.conv = nn.Conv(
            features=self.config.conv_dim[self.layer_id],
            kernel_size=(self.config.conv_kernel[self.layer_id],),
            strides=(self.config.conv_stride[self.layer_id],),
            use_bias=self.config.conv_bias,
            kernel_init=jax.nn.initializers.he_normal(dtype=self.dtype),
            padding="VALID",
            dtype=self.dtype,
        )
        self.layer_norm = nn.LayerNorm(epsilon=self.config.layer_norm_eps, dtype=self.dtype)
        self.activation = ACT2FN[self.config.feat_extract_activation]
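The corresponding __call__ is not shown; under the usual conv → LayerNorm → activation ordering it would look roughly like the sketch below (not the verbatim source):

    def __call__(self, hidden_states):
        # Sketch: convolve, normalize over the feature axis, then apply the
        # configured feature-extractor activation.
        hidden_states = self.conv(hidden_states)
        hidden_states = self.layer_norm(hidden_states)
        hidden_states = self.activation(hidden_states)
        return hidden_states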
  def __call__(self, images: jnp.ndarray, train: Optional[bool] = None):
    train = nn.module.merge_param("train", self.train, train)
    transformer = self.transformer or {}
    # Convert images to patches.
    x = self.patches(images, self.hidden_size, self.patch_size, self.patch_grid)
    # Add "class" token if necessary.
    n, _, c = x.shape
    if self.classifier == "token":
      cls = self.param("cls", nn.initializers.zeros, (1, 1, self.hidden_size))
      cls = jnp.tile(cls, [n, 1, 1])
      x = jnp.concatenate([cls, x], axis=1)
    # Encode tokens.
    x, extra_info = BatchEnsembleEncoder(
        train=train, name="BatchEnsembleTransformer", **transformer)(
            x)
    # Reduce tokens to a single vector representation.
    if self.classifier == "token":
      # Take the first token's output as representation as in BERT.
      x = x[:, 0]
    elif self.classifier == "gap":
      # Average all tokens.
      x = jnp.mean(x, axis=tuple(range(1, x.ndim - 1)))  # (1,) or (1, 2)
    elif self.classifier == "map":
      probe = self.param("probe", nn.initializers.xavier_uniform(), (1, 1, c))
      probe = jnp.tile(probe, [n, 1, 1])
      attention = nn.MultiHeadDotProductAttention(
          deterministic=not train,
          num_heads=transformer.get("attention", {}).get("num_heads", 1),
          kernel_init=nn.initializers.xavier_uniform())
      x = attention(inputs_q=probe, inputs_kv=x)
      y = nn.LayerNorm()(x)
      y = patch_transformer_lib.MlpBlock(
          mlp_dim=transformer["mlp_dim"],
          dropout_rate=0,
          deterministic=not train)(y)
      x = (x + y)[:, 0]
    else:
      raise ValueError(f"Unknown classifier: {self.classifier}")

    if self.representation_size is None:
      x = identity.IdentityLayer(name="pre_logits")(x)
    else:
      x = nn.Dense(self.representation_size, name="pre_logits")(x)
      x = nn.tanh(x)

    x = nn.Dense(self.num_classes, kernel_init=self.head_kernel_init,
                 name="head")(x)
    return x, extra_info
Example #12
    def __call__(self,
                 inputs,
                 inputs_positions=None,
                 inputs_segmentation=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data
      inputs_positions: input subsequence positions for packed examples.
      inputs_segmentation: input segmentation info for packed examples.

    Returns:
      output of a transformer encoder.
    """
        cfg = self.config

        assert inputs.ndim == 2  # (batch, len)

        # Padding Masks
        src_padding_mask = (inputs > 0)[..., None]

        # Input Embedding
        if self.shared_embedding is None:
            input_embed = nn.Embed(
                num_embeddings=cfg.vocab_size,
                features=cfg.emb_dim,
                embedding_init=nn.initializers.normal(stddev=1.0))
        else:
            input_embed = self.shared_embedding
        x = inputs.astype('int32')
        x = input_embed(x)
        x = AddPositionEmbs(config=cfg, name='posembed_input')(
            x, inputs_positions=inputs_positions)
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)

        x = x.astype(cfg.dtype)

        # Input Encoder
        for lyr in range(cfg.num_layers):
            x = Encoder1DBlock(config=cfg, name=f'encoderblock_{lyr}')(
                x,
                padding_mask=src_padding_mask,
                inputs_segmentation=inputs_segmentation)

        encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

        return encoded
    def setup(self):
        self.embed_dim = self.config.hidden_size

        self.wte = nn.Embed(
            self.config.vocab_size,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.wpe = nn.Embed(
            self.config.max_position_embeddings,
            self.embed_dim,
            embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
        )
        self.dropout = nn.Dropout(rate=self.config.embd_pdrop)
        self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)
        self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)
    def __call__(self, inputs, is_training: bool):
        x = PatchEmbedBlock(patch_shape=self.patch_shape,
                            embed_dim=self.embed_dim,
                            use_bias=True,
                            dtype=self.dtype)(inputs)

        for _ in range(self.num_layers):
            x = MixerBlock(tokens_expand_ratio=self.tokens_expand_ratio,
                           channels_expand_ratio=self.channels_expand_ratio,
                           activation_fn=self.activation_fn,
                           dtype=self.dtype)(x, is_training=is_training)

        x = nn.LayerNorm(dtype=self.dtype)(x)
        x = jnp.mean(x, axis=1)
        output = nn.Dense(features=self.num_classes, dtype=self.dtype)(x)
        return output
Example #15
    def __call__(self, inputs):
        config = self.config
        assert inputs.ndim == 3  # (batch, len, embed)
        y = AddPositionEmbs(config=config)(inputs)
        y = nn.Dropout(rate=config.dropout_rate)(
            y, deterministic=config.deterministic)
        assert issubclass(type(self.transformer_layer), partial)
        for l in range(config.num_layers):
            y = self.transformer_layer(name='transformer_layer_%d' % l)(y)
        y = nn.LayerNorm(dtype=config.dtype)(y)

        out = nn.Dense(self.pred_dim,
                       dtype=config.dtype,
                       kernel_init=config.kernel_init,
                       bias_init=config.bias_init)(y)
        return out
Example #16
 def setup(self):
     self.word_embeddings = nn.Embed(
         self.config.vocab_size,
         self.config.dim,
         embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
     )
     if not self.config.sinusoidal_pos_embds:
         self.position_embeddings = nn.Embed(
             self.config.max_position_embeddings,
             self.config.dim,
             embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
         )
     else:
         self.pos_encoding = positional_encoding(self.config.max_position_embeddings, self.config.dim)
     self.LayerNorm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.dropout)
 def setup(self):
     self.word_embeddings = nn.Embed(
         self.config.vocab_size,
         self.config.hidden_size,
         embedding_init=jax.nn.initializers.normal(
             stddev=self.config.initializer_range),
     )
     self.token_type_embeddings = nn.Embed(
         self.config.type_vocab_size,
         self.config.hidden_size,
         embedding_init=jax.nn.initializers.normal(
             stddev=self.config.initializer_range),
     )
     self.LayerNorm = nn.LayerNorm(epsilon=self.config.layer_norm_eps,
                                   dtype=self.dtype)
     self.dropout = nn.Dropout(rate=self.config.hidden_dropout_prob)
Example #18
    def __call__(self, inputs):
        cfg = self.config
        assert inputs.ndim == 3

        dense = partial(nn.DenseGeneral,
                        axis=-1,
                        features=(cfg.num_heads, cfg.dim_head),
                        use_bias=False,
                        kernel_init=cfg.kernel_init,
                        precision=cfg.precision)

        query, key, value = (dense(dtype=cfg.dtype)(inputs),
                             dense(dtype=cfg.dtype)(inputs),
                             dense(dtype=cfg.dtype)(inputs))

        query = query / jnp.sqrt(cfg.dim_head).astype(cfg.dtype)

        attn_weights = jnp.einsum('b q h d, b k h d -> b h q k',
                                  query,
                                  key,
                                  precision=cfg.precision)
        attn_weights = nn.softmax(attn_weights).astype(cfg.dtype)

        if cfg.shared_theta:
            attn_weights = self.theta_transform(attn_weights)
        else:
            attn_weights = ThetaTransform(config=cfg)(attn_weights)

        attn_weights = nn.LayerNorm()(attn_weights)

        out = jnp.einsum('b h q k, b q h d -> b k h d',
                         attn_weights,
                         value,
                         precision=cfg.precision)

        if (cfg.num_heads * cfg.dim_head) != cfg.emb_dim:
            out = nn.DenseGeneral(features=cfg.emb_dim,
                                  axis=(-2, -1),
                                  dtype=cfg.dtype,
                                  precision=cfg.precision,
                                  kernel_init=cfg.kernel_init,
                                  bias_init=cfg.bias_init)(out)
        else:
            out = rearrange(out, 'b k h d -> b k (h d)')

        return out
Example #19
 def __call__(self,
              intermediate_output,
              attention_output,
              deterministic: bool = True):
     hidden_states = nn.Dense(
         attention_output.shape[-1],
         kernel_init=jax.nn.initializers.normal(self.kernel_init_scale,
                                                self.dtype),
         name="dense",
         dtype=self.dtype,
     )(intermediate_output)
     hidden_states = nn.Dropout(rate=self.dropout_rate)(
         hidden_states, deterministic=deterministic)
     hidden_states = nn.LayerNorm(name="layer_norm",
                                  dtype=self.dtype)(hidden_states +
                                                    attention_output)
     return hidden_states
Example #20
    def setup(self):

        self.embedder = embedding.DictEmbed({
            'token_ids':
            embedding.Embed(
                num_embeddings=self.vocab_size,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            ),
            'position_ids':
            embedding.Embed(
                num_embeddings=self.max_positions,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            ),
            'segment_ids':
            embedding.Embed(
                num_embeddings=self.num_segments,
                embedding_dim=self.hidden_size,
                dtype=self.dtype,
                embedding_init=self.kernel_init,
            )
        })

        self.embeddings_layer_norm = nn.LayerNorm(
            epsilon=self.layer_norm_epsilon)
        self.embeddings_dropout = nn.Dropout(rate=self.dropout_rate)

        self.encoder = transformer.TransformerBlock(
            num_layers=self.num_layers,
            model_dim=self.hidden_size,
            intermediate_dim=self.intermediate_dim,
            num_heads=self.num_attention_heads,
            dropout_rate=self.dropout_rate,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            layer_norm_epsilon=self.layer_norm_epsilon,
        )

        self.mention_projector = nn.Dense(
            features=self.mention_encoding_dim,
            dtype=self.dtype,
        )
Example #21
    def __call__(self, x, mask):
        n, d, h, dh, dim = *x.shape, self.heads, self.dim_head, self.dim

        if d != dim:
            x = nn.Dense(features=dim)(x)

        cls_token = self.param("cls", self.cls_init, (1, x.shape[-1]))
        to_norm_out = nn.LayerNorm()

        sincos = fixed_pos_embedding(n, self.dim_head)

        x = np.concatenate((cls_token, x), axis=0)

        for attn, ff in self.layers:
            x = attn(x, pos_emb=sincos, mask=mask) + x
            x = ff(x) + x

        x = to_norm_out(x)
        return x
  def __call__(self,
               inputs: jnp.ndarray,
               inputs_positions: Optional[jnp.ndarray] = None,
               train: Optional[bool] = None):
    """Applies Transformer model on the inputs."""
    train = nn.module.merge_param("train", self.train, train)
    dtype = self.dtype or inputs.dtype
    assert inputs.ndim == 3  # (batch, len, emb)
    # List indicating which MLPs to substitute with BatchEnsemble MLPs.
    be_layers = self.be_layers
    if be_layers is None:
      be_layers = list(range(1, self.num_layers, 2))

    x = patch_transformer_lib.AddPositionEmbs(name="posembed_input")(
        inputs, inputs_positions)
    x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)

    be_params = dict(ens_size=self.ens_size,
                     random_sign_init=self.random_sign_init)
    mlp_params = dict(dtype=dtype, deterministic=not train, name="mlp")
    mlp_params_dense = dict(dropout_rate=self.dropout_rate,
                            mlp_dim=self.mlp_dim)
    mlp_dense = functools.partial(patch_transformer_lib.MlpBlock, **mlp_params,
                                  **mlp_params_dense)
    be_block = functools.partial(BatchEnsembleMlpBlock, **mlp_params,
                                 **mlp_params_dense, **be_params)
    extra_info = dict()
    for lyr in range(self.num_layers):
      encoder_block = functools.partial(
          patch_transformer_lib.Encoder1DBlock,
          num_heads=self.num_heads,
          dtype=dtype,
          dropout_rate=self.dropout_rate,
          deterministic=not train,
          attention_dropout_rate=self.attention_dropout_rate,
          name=f"encoderblock_{lyr}")
      if lyr in be_layers:
        x = encoder_block(mlp_class=be_block)(x)
      else:
        x = encoder_block(mlp_class=mlp_dense)(x)
    encoded = nn.LayerNorm(name="encoder_norm")(x)

    return encoded, extra_info
Example #23
 def __call__(self, z, a, key):
     kernel_initializer = jax.nn.initializers.glorot_uniform()
     x = nn.Dense(features=self.layer_width,
                  kernel_init=kernel_initializer)(jnp.concatenate([z, a]))
     x = nn.LayerNorm()(x)
     x = nn.relu(x)
     mu = nn.Dense(features=self.embedding_dim,
                   kernel_init=kernel_initializer)(x)
     if self.probabilistic:
         sigma = nn.Dense(features=self.embedding_dim,
                          kernel_init=kernel_initializer)(x)
         sigma = nn.sigmoid(sigma)
         sigma = self.min_sigma + (self.max_sigma - self.min_sigma) * sigma
         eps = jax.random.normal(key, shape=sigma.shape)
         sample = mu + sigma * eps
     else:
         sigma = jnp.zeros(self.embedding_dim)
         sample = mu
     return DynamicsModelType(mu, sigma, sample)
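A hedged usage sketch for a module like the one above; the class name DynamicsModel and its constructor fields are assumptions chosen to match the attributes referenced in the snippet:

import jax
import jax.numpy as jnp

# Hypothetical instantiation; field names mirror the attributes used above.
model = DynamicsModel(layer_width=256, embedding_dim=32, probabilistic=True,
                      min_sigma=1e-4, max_sigma=1.0)
z = jnp.zeros((32,))                      # latent state
a = jnp.zeros((4,))                       # action
key = jax.random.PRNGKey(0)
params = model.init(jax.random.PRNGKey(1), z, a, key)
out = model.apply(params, z, a, key)      # DynamicsModelType(mu, sigma, sample)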
Example #24
 def setup(self):
     self.distilbert = FlaxDistilBertModule(self.config, dtype=self.dtype)
     self.vocab_transform = nn.Dense(
         self.config.dim,
         dtype=self.dtype,
         kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
     )
     self.vocab_layer_norm = nn.LayerNorm(epsilon=1e-12, dtype=self.dtype)
     if self.config.tie_word_embeddings:
         self.vocab_projector = FlaxDistilBertLMDecoder(
             self.config,
             dtype=self.dtype,
         )
     else:
         self.vocab_projector = nn.Dense(
             self.config.vocab_size,
             dtype=self.dtype,
             kernel_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),
         )
    def __call__(self, inputs, is_training: bool):
        inputs = zero_pad_and_reshape(inputs)

        x = CvTSelfAttentionBlock(num_heads=self.num_heads,
                                  kernel_size=self.kernel_size,
                                  use_bias=self.use_bias,
                                  bn_momentum=self.bn_momentum,
                                  bn_epsilon=self.bn_epsilon,
                                  dtype=self.dtype)(inputs,
                                                    is_training=is_training)

        x = x + rearrange(inputs, 'b h w d -> b (h w) d')

        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = FFBlock(expand_ratio=self.expand_ratio,
                    activation_fn=self.activation_fn,
                    dtype=self.dtype)(y, is_training=is_training)
        output = x + y
        return output
Example #26
  def __call__(self,
               input_ids,
               type_ids,
               deterministic = False):
    """Applies EmbeddingLayer module.

    Args:
      input_ids: Batch of tokenized inputs of shape <int>[BATCH_SIZE,
        MAX_SEQ_LENGTH].
      type_ids: Ids partitioning input into different types.
      deterministic: Whether or not to apply dropout to output embeddings.

    Returns:
      Embedded tokens of shape <float>[BATCH_SIZE, MAX_SEQ_LENGTH, EMB_DIM].
    """
    word_embeddings = nn.Embed(
        num_embeddings=self.config.vocab_size,
        features=self.config.d_emb,
        embedding_init=default_kernel_init,
        name="word")(
            input_ids)
    position_embeddings = PositionalEncoding(
        max_seq_length=self.config.max_seq_length,
        posemb_init=default_kernel_init,
        name="position")(
            word_embeddings)
    type_embeddings = nn.Embed(
        num_embeddings=self.config.type_vocab_size,
        features=self.config.d_emb,
        embedding_init=default_kernel_init,
        name="type")(
            type_ids)

    embeddings = word_embeddings + position_embeddings + type_embeddings
    embeddings = nn.LayerNorm(
        epsilon=LAYER_NORM_EPSILON, name="layer_norm")(
            embeddings)
    embeddings = nn.Dense(
        self.config.d_model, name="hidden_mapping_in")(
            embeddings)
    return nn.Dropout(rate=self.config.dropout_rate)(
        embeddings, deterministic=deterministic)
Example #27
    def __call__(self, inputs, encoder_mask=None):
        """Applies Transformer model on the inputs.

    Args:
      inputs: input data.
      encoder_mask: encoder self-attention mask.

    Returns:
      output of a transformer encoder.
    """
        cfg = self.config

        x = inputs
        # Input Encoder
        for lyr in range(cfg.num_layers):
            x = Encoder1DBlock(config=cfg,
                               name=f'encoderblock_{lyr}')(x, encoder_mask)

        encoded = nn.LayerNorm(dtype=cfg.dtype, name='encoder_norm')(x)

        return encoded
    def __call__(self, node_feature_embeddings, node_position_embeddings,
                 adjacency_mat, qstar):
        """Applies the TransformerEncoder module.

    Args:
      node_feature_embeddings: Embeddings representing nodes.
      node_position_embeddings: Embeddings representing node positions.
      adjacency_mat: Adjacency matrix over the nodes. Not used for now.
      qstar: float tensor of shape (num_of_nodes,) The optimal q weighting over
        the nodes of the graph, from the subgraph selection module.

    Returns:
      encoded: Encoded nodes, with extra Graph node at the end.
    """
        cfg = self.config
        x = node_feature_embeddings + node_position_embeddings

        # Add average weight to graph node for scale
        qstar = jnp.append(qstar, jnp.mean(qstar))

        # Multiply embeddings by node weights. => learn the agent model.
        x = x * qstar[Ellipsis, None]
        x = nn.Dropout(rate=cfg.dropout_rate)(x,
                                              deterministic=cfg.deterministic)

        x = x.astype(cfg.dtype)
        # TODO(gnegiar): Plot x here to check
        # Keep nodes with positive weights
        mask1d = qstar != 0
        encoder_mask = nn.attention.make_attention_mask(mask1d, mask1d)

        # Input Encoder
        for lyr in range(cfg.num_layers):
            x = Encoder1DBlock(config=cfg,
                               name=f"encoderblock_{lyr}")(x, encoder_mask)
            x = x * mask1d[Ellipsis, None]
            # TODO(gnegiar): Also plot x here
            # Possibly plot gradient norms per encoder layer
            # Plot attention weights?
        encoded = nn.LayerNorm(dtype=cfg.dtype, name="encoder_norm")(x)
        return encoded
    def __call__(self, inputs, is_training: bool):

        x = PatchEmbedBlock(patch_shape=self.patch_shape,
                            embed_dim=self.embed_dim,
                            dtype=self.dtype)(inputs)

        x = Encoder(num_layers=self.num_layers,
                    num_heads=self.num_heads,
                    expand_ratio=self.expand_ratio,
                    attn_dropout_rate=self.attn_dropout_rate,
                    dropout_rate=self.dropout_rate,
                    stoch_depth_rate=self.stoch_depth_rate,
                    layerscale_eps=self.layerscale_eps,
                    activation_fn=self.activation_fn)(x,
                                                      is_training=is_training)

        b = x.shape[0]
        cls_shape = (1, 1, self.embed_dim)
        cls_token = self.param('cls', nn.initializers.zeros, cls_shape)
        cls_token = jnp.tile(cls_token, [b, 1, 1])

        for _ in range(self.num_layers_token_only):
            cls_token = CAEncoderBlock(
                num_heads=self.num_heads,
                expand_ratio=self.expand_ratio,
                attn_dropout_rate=self.attn_dropout_rate,
                dropout_rate=self.dropout_rate,
                stoch_depth_rate=self.stoch_depth_rate,
                layerscale_eps=self.layerscale_eps,
                activation_fn=self.activation_fn,
                dtype=self.dtype)(x, cls_token, is_training=is_training)

        x = jnp.concatenate([cls_token, x], axis=1)
        x = nn.LayerNorm(dtype=self.dtype)(x)

        cls_token = x[:, 0]
        output = nn.Dense(features=self.num_classes,
                          use_bias=True,
                          dtype=self.dtype,
                          kernel_init=nn.initializers.zeros)(cls_token)
        return output
    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.max_target_positions = self.config.max_position_embeddings
        self.embed_scale = math.sqrt(
            self.config.d_model) if self.config.scale_embedding else 1.0

        self.embed_tokens = nn.Embed(
            self.config.vocab_size,
            embed_dim,
            embedding_init=jax.nn.initializers.normal(self.config.init_std),
        )

        # XGLM is set up so that if padding_idx is specified then offset the embedding ids by 2
        # and adjust num_embeddings appropriately. Other models don't have this hack
        self.offset = 2
        self.embed_positions = create_sinusoidal_positions(
            self.config.max_position_embeddings + self.offset, embed_dim)
        self.layers = FlaxXGLMDecoderLayerCollection(self.config, self.dtype)
        self.layer_norm = nn.LayerNorm(dtype=self.dtype, epsilon=1e-05)
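Most of the transformer examples above share the same pre-norm residual pattern (LayerNorm before each sub-layer, residual add after). A compact, generic sketch of that pattern, with made-up module and field names:

import flax.linen as nn

class PreNormBlock(nn.Module):
    """Generic pre-norm residual block: x + sublayer(LayerNorm(x))."""
    num_heads: int = 8
    mlp_dim: int = 2048

    @nn.compact
    def __call__(self, x):
        # Attention sub-layer: LayerNorm first, residual add after.
        y = nn.LayerNorm()(x)
        y = nn.SelfAttention(num_heads=self.num_heads)(y)
        x = x + y
        # MLP sub-layer: again LayerNorm first, residual add after.
        y = nn.LayerNorm()(x)
        y = nn.Dense(self.mlp_dim)(y)
        y = nn.gelu(y)
        y = nn.Dense(x.shape[-1])(y)
        return x + y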