Example #1
    def __call__(self, x):
        # Helper for building residual units with shared settings.
        R_ = lambda hidden_: ResidualUnit(hidden_features=hidden_,
                                          norm=self.norm,
                                          training=self.training,
                                          activation=nn.gelu)
        # First filter to make features.
        h = nn.Conv(features=self.hidden * self.alpha,
                    use_bias=False,
                    kernel_size=(3, 3),
                    kernel_init=INITS[self.kernel_init])(x)
        h = NORMS[self.norm](use_running_average=not self.training)(h)
        h = nn.gelu(h)
        # 2 stages of continuous segments:
        h = ResidualStitch(hidden_features=self.hidden * self.alpha,
                           output_features=self.hidden * self.alpha,
                           strides=(1, 1),
                           norm=self.norm,
                           training=self.training,
                           activation=nn.gelu)(h)
        h = StatefulContinuousBlock(R=R_(self.hidden * self.alpha),
                                    scheme=self.scheme,
                                    n_step=self.n_step,
                                    n_basis=self.n_basis,
                                    basis=self.basis,
                                    training=self.training)(h)

        # Pool and linearly classify:
        h = NORMS[self.norm](use_running_average=not self.training)(h)
        h = nn.gelu(h)
        h = nn.avg_pool(h, window_shape=(8, 8), strides=(8, 8))
        h = h.reshape((h.shape[0], -1))
        h = nn.Dense(features=self.n_classes)(h)
        return nn.log_softmax(h)  # return log-probabilities
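
Note: the snippets on this page are `__call__` methods of Flax (`flax.linen`) modules. As a point of reference, the following is a minimal, self-contained sketch of how such a module using `nn.gelu` is defined, initialized, and applied; the module name `TinyGeluNet` and the layer sizes are hypothetical and not taken from the example above.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyGeluNet(nn.Module):
    hidden: int = 64
    n_classes: int = 10

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden)(x)
        x = nn.gelu(x)                # GELU nonlinearity, as in the examples on this page
        x = nn.Dense(self.n_classes)(x)
        return nn.log_softmax(x)      # log-probabilities

model = TinyGeluNet()
variables = model.init(jax.random.PRNGKey(0), jnp.ones((1, 32)))  # initialize parameters
log_probs = model.apply(variables, jnp.ones((4, 32)))             # forward pass
print(log_probs.shape)  # (4, 10)
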
Example #2
 def __call__(self, x):
     Conv1x1_ = partial(Conv1x1, precision=self.conv_precision)
     Conv3x3_ = partial(Conv3x3 if self.use_3x3 else Conv1x1,
                        precision=self.conv_precision)
     x_ = Conv1x1_(self.middle_width)(nn.gelu(x))
     x_ = Conv3x3_(self.middle_width)(nn.gelu(x_))
     x_ = Conv3x3_(self.middle_width)(nn.gelu(x_))
     x_ = Conv1x1_(self.out_width,
                   kernel_init=lecun_normal(self.last_scale))(nn.gelu(x_))
     out = x + x_ if self.residual else x_
     if self.down_rate > 1:
         window_shape = 2 * (self.down_rate, )
         out = nn.avg_pool(out, window_shape, window_shape)
     return out
Example #3
 def __call__(self, x):
     x = nn.Conv(features=28, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(28)(x)
     x = nn.gelu(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = x.reshape((x.shape[0], -1))
     mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
     logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
     return mean_x, logvar_x
Example #4
 def __call__(self,
              inputs: jnp.ndarray,
              deterministic: Optional[bool] = None) -> jnp.ndarray:
   """Applies BatchEnsemble MlpBlock module."""
   deterministic = nn.module.merge_param("deterministic", self.deterministic,
                                         deterministic)
   dtype = self.dtype or inputs.dtype
   inputs = jnp.asarray(inputs, self.dtype)
   out_dim = self.out_dim or inputs.shape[-1]
   x = ed.nn.DenseBatchEnsemble(
       self.mlp_dim,
       self.ens_size,
       activation=None,
       use_ensemble_bias=self.use_bias,
       alpha_init=ed.nn.utils.make_sign_initializer(self.random_sign_init),
       gamma_init=ed.nn.utils.make_sign_initializer(self.random_sign_init),
       kernel_init=self.kernel_init,
       bias_init=self.bias_init,
       dtype=dtype)(inputs)
   x = nn.gelu(x)
   x = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)(x)
   output = ed.nn.DenseBatchEnsemble(
       out_dim,
       self.ens_size,
       activation=None,
       use_ensemble_bias=self.use_bias,
       alpha_init=ed.nn.utils.make_sign_initializer(self.random_sign_init),
       gamma_init=ed.nn.utils.make_sign_initializer(self.random_sign_init),
       kernel_init=self.kernel_init,
       bias_init=self.bias_init,
       dtype=dtype)(x)
   output = nn.Dropout(rate=self.dropout_rate,
                       deterministic=deterministic)(output)
   return output
Example #5
    def __call__(self, inputs, temb, deterministic):
        """Applies Transformer MlpBlock module."""
        cfg = self.config
        actual_out_dim = (inputs.shape[-1]
                          if self.out_dim is None else self.out_dim)
        x = nn.Dense(cfg.mlp_dim,
                     dtype=cfg.dtype,
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init)(inputs)

        # Add in the time embedding if applicable.
        if temb is not None:
            x += nn.Dense(cfg.mlp_dim,
                          dtype=cfg.dtype,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init)(temb)[:, None, :]
        x = nn.gelu(x)
        x = nn.Dropout(rate=cfg.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(actual_out_dim,
                          dtype=cfg.dtype,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init)(x)
        output = nn.Dropout(rate=cfg.dropout_rate)(output,
                                                   deterministic=deterministic)
        return output
Example #6
 def __call__(self, z):
     shape_before_flattening, flatten_out_size = self.flatten_enc_shape()
     #print(shape_before_flattening, flatten_out_size)
     x = nn.Dense(flatten_out_size, name='fc1')(z)
     x = nn.gelu(x)
     x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
     x = nn.ConvTranspose(features=32, kernel_size=(3, 3),
                          strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = nn.ConvTranspose(features=28, kernel_size=(3, 3),
                          strides=(2, 2))(x)
     x = nn.GroupNorm(28)(x)
     x = nn.gelu(x)
     x = nn.ConvTranspose(features=1, kernel_size=(3, 3), strides=(2, 2))(x)
     return x
Example #7
    def __call__(
        self,
        encoded_input: Array,
        mlm_target_positions: Array,
        shared_embedding: Array,
    ) -> Array:
        """Perform masked language modeling scoring.

    Args:
      encoded_input: [bsz, n_tokens, hidden_size].
      mlm_target_positions: [bsz, max_mlm_targets] positions of mlm targets in
        passage.
      shared_embedding: [vocab_size, hidden_size] word embedding array, shared
        with initial embedding.

    Returns:
      Array of masked language modeling logits.
    """

        target_encodings = jut.matmul_slice(encoded_input,
                                            mlm_target_positions)
        target_encodings = self.dense(target_encodings)
        target_encodings = nn.gelu(target_encodings)
        target_encodings = self.layer_norm(target_encodings)

        mlm_logits = self.embedding_dense.apply(
            {'params': {
                'kernel': shared_embedding.T
            }}, target_encodings)
        mlm_logits = mlm_logits + self.bias

        return mlm_logits
Example #8
 def __call__(self, x, deterministic=False):
     features = x.shape[-1]
     x = nn.Dense(features * self.mult)(x)
     x = nn.gelu(x)
     x = nn.Dropout(self.dropout)(x, deterministic=deterministic)
     x = nn.Dense(features)(x)
     return x
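
Note: many examples on this page pair `nn.gelu` with `nn.Dropout` gated by a `deterministic` flag. The sketch below reuses the structure of the example above with hypothetical names and sizes, and shows the usual calling convention: when dropout is active, `Module.apply` must be given a 'dropout' RNG stream.

import jax
import jax.numpy as jnp
import flax.linen as nn

class FeedForward(nn.Module):
    mult: int = 4
    dropout: float = 0.1

    @nn.compact
    def __call__(self, x, deterministic=False):
        features = x.shape[-1]
        x = nn.Dense(features * self.mult)(x)
        x = nn.gelu(x)
        x = nn.Dropout(self.dropout)(x, deterministic=deterministic)
        return nn.Dense(features)(x)

ff = FeedForward()
x = jnp.ones((2, 8, 16))
variables = ff.init(jax.random.PRNGKey(0), x, deterministic=True)
# With deterministic=False, dropout masks are sampled from the 'dropout' RNG stream.
y = ff.apply(variables, x, deterministic=False,
             rngs={'dropout': jax.random.PRNGKey(1)})
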
Example #9
    def __call__(self,
                 input_ids,
                 input_mask,
                 type_ids,
                 masked_lm_positions,
                 masked_lm_labels,
                 masked_lm_weights,
                 next_sentence_labels,
                 deterministic=False):
        """Applies pre-training model on inputs.

    Args:
      input_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] tokenized inputs.
      input_mask: <bool>[BATCH_SIZE, MAX_SEQ_LENGTH] mask separating actual
        inputs from padding. Only used by BERT.
      type_ids: <int>[BATCH_SIZE, MAX_SEQ_LENGTH] Ids partitioning input into
        different types.
      masked_lm_positions: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] indices
        indicating which inputs are masked.
      masked_lm_labels: <int>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] true labels
        for masked inputs.
      masked_lm_weights: <float>[BATCH_SIZE, MAX_PREDICTIONS_PER_SEQ] relative
        weighting for masked inputs.
      next_sentence_labels: <int>[BATCH_SIZE, 1] Labels for next sentence
        prediction task.
      deterministic: Whether or not to apply dropout to input.

    Returns:
      Loss and metrics for given inputs.
    """
        sequence_output, pooled_output = EncoderModel(
            self.config, random_seed=self.random_seed,
            name="encoder")(input_ids,
                            input_mask,
                            type_ids,
                            deterministic=deterministic)

        masked_lm_output = layers.gather(sequence_output, masked_lm_positions)
        masked_lm_output = nn.Dense(self.config.d_emb,
                                    kernel_init=default_kernel_init,
                                    name="predictions_dense")(masked_lm_output)
        masked_lm_output = nn.gelu(masked_lm_output)
        masked_lm_output = nn.LayerNorm(
            epsilon=LAYER_NORM_EPSILON,
            name="predictions_layer_norm")(masked_lm_output)
        masked_lm_logits = layers.OutputProjection(
            kernel=self._get_embedding_table(),
            name="predictions_output")(masked_lm_output)

        next_sentence_logits = layers.OutputProjection(
            n_out=2, kernel_init=default_kernel_init,
            name="classification")(pooled_output)

        return _compute_pretraining_metrics(masked_lm_logits,
                                            next_sentence_logits,
                                            masked_lm_labels,
                                            masked_lm_weights,
                                            next_sentence_labels)
Example #10
  def flatten_enc_shape(self):
    x = jnp.ones([1, 64, 64, 3], jnp.float32)

    # Build Encoder
    for h_dim in self.hidden_dims:
      x = nn.Conv(features=h_dim, kernel_size=(3, 3), strides=(2, 2))(x)
      x = nn.gelu(x)

    return x.shape, int(np.prod(x.shape))
Example #11
    def __call__(
        self,
        encoded_input: Array,
        retrieval_values: Array,
        retrieval_scores: Array,
        mention_batch_positions: Array,
        mention_start_positions: Array,
        mention_end_positions: Array,
        mention_mask: Array,
        deterministic: bool,
    ) -> Array:

        # Generate mention values from input representation
        mention_start_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_start_positions))
        mention_end_encodings = jut.matmul_2d_index_select(
            encoded_input, (mention_batch_positions, mention_end_positions))

        passage_mention_values = self.value_projection(
            jnp.concatenate((mention_start_encodings, mention_end_encodings),
                            axis=-1))
        k_retrieval = retrieval_scores.shape[-1]
        passage_mention_values = jnp.expand_dims(passage_mention_values,
                                                 axis=1)
        passage_mention_values = jnp.tile(passage_mention_values,
                                          (1, k_retrieval, 1))

        # Generate concatenated values of shape [mentions, k, 2 * retrieval_dim]
        concat_values = jnp.concatenate(
            (passage_mention_values, retrieval_values), axis=-1)

        # MLP over concatenation mention value and individual retrieved value
        concat_values = nn.gelu(self.concat_mlp(concat_values))
        concat_values = self.concat_dense(concat_values)
        concat_values = self.concat_dropout(concat_values, deterministic)

        # Additional MLP layers
        for concat_mlp_layer in self.additional_concat_mlp_layers:
            concat_values = concat_mlp_layer(concat_values, deterministic)

        pooled_values = jnp.einsum('qk,qkd->qd', retrieval_scores,
                                   concat_values)

        # MLP layers applied to pooled retrieval values
        for pooled_mlp_layer in self.pooled_mlp_layers:
            pooled_values = pooled_mlp_layer(pooled_values, deterministic)
        pooled_values = pooled_values * mention_mask.reshape(-1, 1)

        encoded_input = jut.matmul_2d_index_add(
            encoded_input, (mention_batch_positions, mention_start_positions),
            pooled_values)

        encoded_input = self.layer_norm(encoded_input)

        return encoded_input
Example #12
 def __call__(self, inputs, deterministic):
     h = nn.Dense(features=_resolve(self.hidden_params, inputs.shape[-1]),
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(  # pytype: disable=wrong-arg-types
                      inputs)
     h = nn.gelu(h)
     h = nn.Dense(features=_resolve(self.out_params, inputs.shape[-1]),
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(  # pytype: disable=wrong-arg-types
                      h)
     return h
Example #13
  def __call__(self, x):
    # Build Encoder
    for h_dim in self.hidden_dims:
      x = nn.Conv(features=h_dim, kernel_size=(3, 3),
                  strides=(2, 2), padding="valid")(x)
      x = nn.GroupNorm()(x)
      x = nn.gelu(x)

    x = x.reshape((x.shape[0], -1))
    mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
    logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
    return mean_x, logvar_x
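
Note: the encoder above returns a mean and a log-variance for a VAE-style latent. As a hedged usage sketch, not part of the original example, these outputs are typically turned into a latent sample with the reparameterization trick:

import jax
import jax.numpy as jnp

def reparameterize(rng, mean, logvar):
    # z = mean + sigma * eps, with sigma = exp(0.5 * logvar) and eps ~ N(0, I)
    std = jnp.exp(0.5 * logvar)
    eps = jax.random.normal(rng, mean.shape)
    return mean + eps * std
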
Example #14
 def __call__(self, x):
     # if temporal, x is [b, l, d]
     # if spatial, x is  [b, h, w, d]
     shortcut = x
     x = nn.normalization.LayerNorm()(x)
     x = nn.Dense(features=self.ffn_dim)(x)
     x = nn.gelu(x)
     x = SpatialGatingUnit()(x)
     x = nn.Dense(features=self.model_dim)(x)
     x = x + shortcut
     return x
Example #15
    def __call__(self, x):
        dim_in, mult = x.shape[-1], self.mult

        norm = nn.LayerNorm()
        to_intermediate = nn.Dense(features=dim_in * mult)
        to_out = nn.Dense(features=dim_in)

        x = norm(x)
        x = to_intermediate(x)
        x = nn.gelu(x)
        x = to_out(x)
        return x
Example #16
    def __call__(self, x, train=True):
        """Applies Transformer MlpBlock module."""
        inits = dict(
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.normal(stddev=1e-6),
        )

        d = x.shape[2]
        x = nn.Dense(self.mlp_dim or 4 * d, **inits)(x)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout)(x, train)
        x = nn.Dense(d, **inits)(x)
        return x
Example #17
 def __call__(self, inputs, train):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = (inputs.shape[-1]
                       if self.out_dim is None else self.out_dim)
     x = nn.Dense(self.mlp_dim,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(inputs)
     x = nn.gelu(x)
     x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
     output = nn.Dense(actual_out_dim,
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init)(x)
     output = nn.Dropout(rate=self.dropout_rate,
                         deterministic=not train)(output)
     return output
Example #18
    def __call__(self, x: Array, deterministic: bool) -> Array:
        """Applies MLP block update.

    Args:
      x: [..., input_dim].
      deterministic: don't apply dropout if true.

    Returns:
      Updated array x of same shape.
    """
        update = nn.gelu(self.mlp(x))
        update = self.dense(update)
        update = self.dropout(update, deterministic=deterministic)
        x = self.layer_norm(x + update)

        return x
Example #19
  def __call__(self, z):
    shape_before_flattening, flatten_out_size = self.flatten_enc_shape()

    x = nn.Dense(flatten_out_size, name='fc1')(z)
    x = x.reshape((x.shape[0], *shape_before_flattening[1:]))
    
    hidden_dims = self.hidden_dims[::-1]
    # Build Decoder
    for i in range(len(hidden_dims) - 1):
      x = nn.ConvTranspose(features=hidden_dims[i], kernel_size=(3, 3),
                           strides=(2, 2))(x)
      x = nn.GroupNorm()(x)
      x = nn.gelu(x)
    
    x = nn.ConvTranspose(features=3, kernel_size=(3, 3), strides=(2,2))(x)
    x = nn.sigmoid(x)
    return x
Example #20
 def __call__(self, inputs):
   """Applies MLP block of dense layers."""
   cfg = self.config
   actual_out_dim = (inputs.shape[-1] if self.out_dim is None
                     else self.out_dim)
   x = nn.Dense(cfg.mlp_dim,
                dtype=cfg.dtype,
                kernel_init=cfg.kernel_init,
                bias_init=cfg.bias_init)(inputs)
   x = nn.gelu(x)
   x = nn.Dropout(rate=cfg.dropout_rate)(
       x, deterministic=cfg.deterministic)
   output = nn.Dense(actual_out_dim,
                     dtype=cfg.dtype,
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init)(x)
   output = nn.Dropout(rate=cfg.dropout_rate)(
       output, deterministic=cfg.deterministic)
   return output
Example #21
 def __call__(self, inputs, *, deterministic):
     """Applies Transformer MlpBlock module."""
     actual_out_dim = (inputs.shape[-1]
                       if self.out_dim is None else self.out_dim)
     x = nn.Dense(features=self.mlp_dim,
                  dtype=self.dtype,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(  # pytype: disable=wrong-arg-types
                      inputs)
     x = nn.gelu(x)
     x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
     output = nn.Dense(features=actual_out_dim,
                       dtype=self.dtype,
                       kernel_init=self.kernel_init,
                       bias_init=self.bias_init)(  # pytype: disable=wrong-arg-types
                           x)
     output = nn.Dropout(rate=self.dropout_rate)(
         output, deterministic=deterministic)
     return output
Example #22
    def __call__(self, inputs):

        cfg = self.config

        if self.inner:
            dim = cfg.inner_dim
            r = cfg.inner_r
        else:
            dim = cfg.outer_dim
            r = cfg.outer_r

        x = nn.Dense(dim * r,
                     dtype=cfg.dtype,
                     kernel_init=cfg.kernel_init,
                     bias_init=cfg.bias_init)(inputs)
        x = nn.gelu(x)
        output = nn.Dense(dim,
                          dtype=cfg.dtype,
                          kernel_init=cfg.kernel_init,
                          bias_init=cfg.bias_init)(x)
        return output
Example #23
 def __call__(self, x):
     y = nn.Dense(self.mlp_dim)(x)
     y = nn.gelu(y)
     return nn.Dense(x.shape[-1])(y)