Example No. 1
 def __call__(self, inputs, train: bool = False):
   x = ASPP([12, 24, 36], name='ASPP')(inputs)
   x = nn.Conv(256, (3, 3), padding='SAME', use_bias=False, name="conv1")(x)
   x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
   x = nn.relu(x)
   x = nn.Conv(self.num_classes, (1, 1), padding='VALID', use_bias=True, name="conv2")(x)
   return x
Example No. 2
 def __call__(self, x):
     initializer = nn.initializers.variance_scaling(scale=1.0 /
                                                    jnp.sqrt(3.0),
                                                    mode='fan_in',
                                                    distribution='uniform')
     x = x.astype(jnp.float32) / 255.
     x = nn.Conv(features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = x.reshape((-1))  # flatten
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Dense(features=self.num_actions * self.num_atoms,
                  kernel_init=initializer)(x)
     logits = x.reshape((self.num_actions, self.num_atoms))
     probabilities = nn.softmax(logits)
     q_values = jnp.mean(logits, axis=1)
     return atari_lib.RainbowNetworkType(q_values, logits, probabilities)
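The `reshape((-1))` above flattens the whole activation because there is no explicit batch axis: Dopamine-style networks are written for a single observation and batched with `jax.vmap`. A minimal sketch of that convention with a stand-in module (the module, names and shapes here are illustrative, not Dopamine's API):

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyStateNet(nn.Module):
  """Stand-in for a Dopamine-style per-observation network (illustrative)."""
  num_actions: int = 4

  @nn.compact
  def __call__(self, x):          # x is a single observation, e.g. [84, 84, 4]
    x = x.astype(jnp.float32) / 255.
    x = nn.Conv(features=8, kernel_size=(8, 8), strides=(4, 4))(x)
    x = nn.relu(x)
    x = x.reshape((-1))           # flatten everything; there is no batch axis
    return nn.Dense(self.num_actions)(x)

net = TinyStateNet()
obs = jnp.zeros((84, 84, 4), dtype=jnp.uint8)
variables = net.init(jax.random.PRNGKey(0), obs)
batch = jnp.zeros((32, 84, 84, 4), dtype=jnp.uint8)
q_values = jax.vmap(lambda s: net.apply(variables, s))(batch)  # shape [32, 4]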
Example No. 3
    def __call__(self, x):
        initializer = nn.initializers.xavier_uniform()
        conv_out = nn.Conv(features=self.num_ch,
                           kernel_size=(3, 3),
                           strides=1,
                           kernel_init=initializer,
                           padding='SAME')(x)
        if self.use_max_pooling:
            conv_out = nn.max_pool(conv_out,
                                   window_shape=(3, 3),
                                   padding='SAME',
                                   strides=(2, 2))

        for _ in range(self.num_blocks):
            block_input = conv_out
            conv_out = nn.relu(conv_out)
            conv_out = nn.Conv(features=self.num_ch,
                               kernel_size=(3, 3),
                               strides=1,
                               padding='SAME')(conv_out)
            conv_out = nn.relu(conv_out)
            conv_out = nn.Conv(features=self.num_ch,
                               kernel_size=(3, 3),
                               strides=1,
                               padding='SAME')(conv_out)
            conv_out += block_input

        return conv_out
Example No. 4
 def __call__(self, x, support=None):
     initializer = nn.initializers.variance_scaling(scale=1.0 /
                                                    jax.numpy.sqrt(3.0),
                                                    mode='fan_in',
                                                    distribution='uniform')
     if not self.inputs_preprocessed:
         x = networks.preprocess_atari_inputs(x)
     x = nn.Conv(features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = x.reshape((-1))  # flatten
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     return x
Example No. 5
    def __call__(self, inputs):
        x = inputs
        input_filters = x.shape[-1]

        # Expand (block_id controls this block in the keras implementation).
        x = nn.Conv(_depth(input_filters * self.expansion),
                    kernel_size=(1, 1),
                    use_bias=False)(x)
        if self.batch_norm:
            x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
        x = self.activation(x)

        if self.stride == 2:
            x = zero_pad_2d(correct_pad(x, self.kernel_size))(x)
        x = DepthwiseConv2D(kernel_size=(self.kernel_size, self.kernel_size),
                            strides=(self.stride, self.stride),
                            padding="same" if self.stride == 1 else "valid",
                            use_bias=False)(x)
        if self.batch_norm:
            x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
        x = self.activation(x)

        if self.se_ratio:
            x = SEBlock(self.se_ratio)(x)

        x = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(x)
        if self.batch_norm:
            x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)

        if self.stride == 1 and input_filters == self.filters:
            x = inputs + x
        return x
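`_depth` is not defined in these snippets; in the Keras MobileNet code it is the usual "make divisible" channel-rounding helper. A hedged sketch of what such a helper typically looks like (an assumption, not the snippet's actual implementation):

def _depth(v, divisor=8, min_value=None):
  """Round a channel count to a multiple of `divisor` (assumed behaviour)."""
  if min_value is None:
    min_value = divisor
  new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
  # Avoid rounding away more than 10% of the original channel count.
  if new_v < 0.9 * v:
    new_v += divisor
  return int(new_v)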
Example No. 6
    def __call__(self, x, temb=None, train=True):
        if self.activate_before_residual:
            x = activation(x, train, name='init_bn')
            orig_x = x
        else:
            orig_x = x

        block_x = x
        if not self.activate_before_residual:
            block_x = activation(block_x, train, name='init_bn')

        block_x = nn.Conv(self.channels, (3, 3),
                          self.strides,
                          padding='SAME',
                          use_bias=False,
                          kernel_init=conv_kernel_init_fn,
                          name='conv1')(block_x)

        if temb is not None:
            block_x += nn.Dense(self.channels)(nn.swish(temb))[:, None,
                                                               None, :]
        block_x = activation(block_x, train=train, name='bn_2')
        block_x = nn.Conv(self.channels, (3, 3),
                          padding='SAME',
                          use_bias=False,
                          kernel_init=conv_kernel_init_fn,
                          name='conv2')(block_x)

        return _output_add(block_x, orig_x)
Example No. 7
    def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
        # Normalize the input
        x = x.astype(jnp.float32) / 255.0

        # Block 1
        x = linen.Conv(32, [3, 3], strides=[2, 2])(x)
        x = linen.Dropout(0.05, deterministic=not training)(x)
        x = jax.nn.relu(x)

        # Block 2
        x = linen.Conv(64, [3, 3], strides=[2, 2])(x)
        x = linen.BatchNorm(use_running_average=not training)(x)
        x = linen.Dropout(0.1, deterministic=not training)(x)
        x = jax.nn.relu(x)

        # Block 3
        x = linen.Conv(128, [3, 3], strides=[2, 2])(x)

        # Global average pooling
        x = x.mean(axis=(1, 2))

        # Classification layer
        x = linen.Dense(10)(x)

        return x
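Because this model mixes `Dropout` and `BatchNorm`, calling it needs a dropout RNG and mutable batch statistics during training. A minimal sketch of that handling with a small stand-in module (names and shapes are illustrative, not the snippet's own training loop):

import jax
import jax.numpy as jnp
from flax import linen

class TinyBlock(linen.Module):
  """Small stand-in with the same Dropout/BatchNorm pattern (illustrative)."""

  @linen.compact
  def __call__(self, x, training: bool):
    x = linen.Conv(16, [3, 3], strides=[2, 2])(x)
    x = linen.BatchNorm(use_running_average=not training)(x)
    x = linen.Dropout(0.1, deterministic=not training)(x)
    return jax.nn.relu(x)

rng = jax.random.PRNGKey(0)
x = jnp.zeros((8, 32, 32, 3))
model = TinyBlock()
variables = model.init(rng, x, training=False)

# Training step: dropout needs an RNG, batch statistics must be mutable.
y, updates = model.apply(
    variables, x, training=True,
    rngs={'dropout': jax.random.fold_in(rng, 1)},
    mutable=['batch_stats'])
variables = {'params': variables['params'],
             'batch_stats': updates['batch_stats']}

# Evaluation: deterministic dropout, frozen batch statistics.
y = model.apply(variables, x, training=False)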
Example No. 8
  def __call__(self, x):
    """Define the convolutional network architecture.

    Architecture originates from "Human-level control through deep reinforcement
    learning.", Nature 518, no. 7540 (2015): 529-533.
    Note that this is different from the one in "Playing Atari with Deep
    Reinforcement Learning", arxiv.org/abs/1312.5602 (2013).

    Network is used to both estimate policy (logits) and expected state value;
    in other words, hidden layers' params are shared between policy and value
    networks, see e.g.:
    github.com/openai/baselines/blob/master/baselines/ppo1/cnn_policy.py
    """
    dtype = jnp.float32
    x = x.astype(dtype) / 255.
    x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), name='conv1',
                dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), name='conv2',
                dtype=dtype)(x)
    x = nn.relu(x)
    x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), name='conv3',
                dtype=dtype)(x)
    x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=512, name='hidden', dtype=dtype)(x)
    x = nn.relu(x)
    logits = nn.Dense(features=self.num_outputs, name='logits', dtype=dtype)(x)
    policy_log_probabilities = nn.log_softmax(logits)
    value = nn.Dense(features=1, name='value', dtype=dtype)(x)
    return policy_log_probabilities, value
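The returned log-probabilities can be used directly for action sampling, since `jax.random.categorical` accepts unnormalized log-probabilities. A hedged sketch with a tiny stand-in module that mirrors the two-headed (policy, value) output (the stand-in is illustrative, not the network above):

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyActorCritic(nn.Module):
  """Minimal shared-trunk policy/value network (illustrative only)."""
  num_outputs: int = 6

  @nn.compact
  def __call__(self, x):
    x = x.astype(jnp.float32) / 255.
    x = nn.relu(nn.Dense(32)(x.reshape((x.shape[0], -1))))
    logits = nn.Dense(self.num_outputs)(x)
    return nn.log_softmax(logits), nn.Dense(1)(x)

model = TinyActorCritic()
obs = jnp.zeros((8, 84, 84, 4), dtype=jnp.uint8)
variables = model.init(jax.random.PRNGKey(0), obs)
log_probs, values = model.apply(variables, obs)
# Log-probabilities double as logits for categorical sampling.
actions = jax.random.categorical(jax.random.PRNGKey(1), log_probs, axis=-1)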
Example No. 9
    def __call__(self, x):
        """Define the model architecture.

    A small 1-D convolutional network followed by a dense layer that produces
    one output per action.

    Args:
      x: input of shape [N, H, W], where W (typically 1) is the channel axis.

    Returns:
      outputs: per-action outputs of shape [N, num_actions].
    """
        x = x.astype(self.dtype)
        x = nn.Conv(features=self.chan1,
                    kernel_size=[3],
                    strides=1,
                    name='conv1',
                    dtype=self.dtype)(x)
        x = nn.relu(x)
        x = nn.Conv(features=self.chan2,
                    kernel_size=[3],
                    strides=1,
                    name='conv2',
                    dtype=self.dtype)(x)
        x = nn.relu(x)

        x = x.reshape((x.shape[0], -1))  # flatten
        outputs = nn.Dense(features=self.num_actions,
                           name='outputs',
                           dtype=self.dtype)(x)
        return outputs
Example No. 10
 def __call__(self, x):
     initializer = nn.initializers.xavier_uniform()
     if not self.inputs_preprocessed:
         x = preprocess_atari_inputs(x)
     x = nn.Conv(features=32,
                 kernel_size=(8, 8),
                 strides=(4, 4),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(4, 4),
                 strides=(2, 2),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = nn.Conv(features=64,
                 kernel_size=(3, 3),
                 strides=(1, 1),
                 kernel_init=initializer)(x)
     x = nn.relu(x)
     x = x.reshape((-1))  # flatten
     x = nn.Dense(features=512, kernel_init=initializer)(x)
     x = nn.relu(x)
     q_values = nn.Dense(features=self.num_actions,
                         kernel_init=initializer)(x)
     return atari_lib.DQNNetworkType(q_values)
Example No. 11
  def __call__(self, inputs, train: bool = False):
    inter_channels = np.shape(inputs)[-1] // 4
    x = nn.Conv(inter_channels, (3, 3), padding='SAME', use_bias=False, name="conv1")(inputs)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)
    x = nn.Dropout(0.1)(x, deterministic=not train)
    x = nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=True, name="conv2")(x)

    return x
Example No. 12
 def __call__(self, x):
     h = NORMS[self.norm](use_running_average=not self.training)(x)
     h = self.activation(h)
     h = nn.Conv(self.hidden_features, (3, 3), use_bias=self.use_bias,
                 kernel_init=INITS[self.kernel_init])(h)
     h = NORMS[self.norm](use_running_average=not self.training)(h)
     h = self.activation(h)
     h = nn.Conv(x.shape[-1], (3, 3), use_bias=self.use_bias,
                 kernel_init=INITS[self.kernel_init])(h)
     return self.epsilon * h
Example No. 13
    def __call__(self, x, t_embed):
        """Apply the residual block.

    Args:
      x: Inputs of shape [batch, <spatial>, features].
      t_embed: Embedded time steps of shape [batch, dim].

    Returns:
      Mapped inputs of shape [batch, <spatial>, features] for the output and
      skip connections.
    """
        in_features = x.shape[-1]
        if in_features != self.features:
            raise ValueError(
                f'DiffWave ResBlock requires the same number of input ({in_features}) '
                f'and output ({self.features}) features.')

        h = x
        if t_embed is not None:
            # Project time step embedding.
            t_embed = nn.Dense(in_features,
                               name='step_proj')(self.activation(t_embed))
            # Reshape to [batch, 1, ..., 1, in_features] for broadcast.
            t_embed = jnp.reshape(t_embed,
                                  (-1, ) + (1, ) * len(self.kernel_size) +
                                  (in_features, ))
            h += t_embed

        # Dilated gated conv.
        u = layers.CausalConv(self.features,
                              self.kernel_size,
                              kernel_dilation=self.kernel_dilation,
                              kernel_init=self.kernel_init,
                              padding='VALID' if self.is_causal else 'SAME',
                              is_causal=self.is_causal,
                              name='dilated_tanh')(h)
        v = layers.CausalConv(self.features,
                              self.kernel_size,
                              kernel_dilation=self.kernel_dilation,
                              kernel_init=self.kernel_init,
                              padding='VALID' if self.is_causal else 'SAME',
                              is_causal=self.is_causal,
                              name='dilated_sigmoid')(h)
        y = jax.nn.tanh(u) * jax.nn.sigmoid(v)

        # Residual and skip convs.
        residual = nn.Conv(self.features, (1, ) * len(self.kernel_size),
                           kernel_init=self.kernel_init,
                           name='residual')(y)
        skip = nn.Conv(self.skip_features or self.features,
                       (1, ) * len(self.kernel_size),
                       kernel_init=self.kernel_init,
                       name='skip')(y)

        return (x + residual) / np.sqrt(2.), skip
Example No. 14
    def setup(self):

        self.layer1 = nn.Conv(self.in_features, (3, 3), strides=1)
        self.group_l1 = nn.normalization.GroupNorm(3)
        self.mid = UnitUnet(self.d, 16, 16)
        self.straight1 = nn.Conv(self.mid.outfeatures + self.layer1.features,
                                 (3, 3),
                                 strides=(1, 1))
        self.group_straight1 = nn.normalization.GroupNorm(
            self.mid.outfeatures + self.layer1.features)
        self.straight2 = nn.Conv(self.out_features, (3, 3), strides=(1, 1))
Example No. 15
 def __call__(self, x):
     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # flatten
     x = nn.Dense(features=256)(x)
     x = nn.relu(x)
     x = nn.Dense(features=10)(x)
     return x
Example No. 16
 def __call__(self, x):
     x = nn.Conv(features=32, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = nn.Conv(features=64, kernel_size=(3, 3))(x)
     x = nn.relu(x)
     x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
     x = x.reshape((x.shape[0], -1))  # Flatten
     x = nn.Dense(features=256)(x)
     x = nn.relu(x)
     x = nn.Dense(features=10)(x)  # There are 10 classes in MNIST
     x = nn.log_softmax(x)
     return x
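Since the model already returns log-probabilities, the negative log-likelihood loss is just an indexed mean. A minimal sketch of initializing and differentiating such a module (the class name `CNN` and the MNIST shapes are assumptions):

import jax
import jax.numpy as jnp
from flax import linen as nn

class CNN(nn.Module):  # assumed wrapper around the __call__ body shown above
  @nn.compact
  def __call__(self, x):
    x = nn.relu(nn.Conv(features=32, kernel_size=(3, 3))(x))
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.relu(nn.Conv(features=64, kernel_size=(3, 3))(x))
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))
    x = nn.relu(nn.Dense(features=256)(x))
    return nn.log_softmax(nn.Dense(features=10)(x))

images = jnp.zeros((8, 28, 28, 1))
labels = jnp.zeros((8,), dtype=jnp.int32)
model = CNN()
variables = model.init(jax.random.PRNGKey(0), images)

def nll_loss(variables, images, labels):
  log_probs = model.apply(variables, images)  # [8, 10] log-probabilities
  return -jnp.mean(log_probs[jnp.arange(labels.shape[0]), labels])

loss, grads = jax.value_and_grad(nll_loss)(variables, images, labels)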
Example No. 17
    def __call__(self, x, t, mask, train, context=None):
        """Apply the WaveDiff network.

    Args:
      x: Inputs of shape [batch, <spatial>, features].
      t: Time steps of shape [batch].
      mask: Array of the same shape as `x` giving the auto-regressive mask.
      train: If True, the model is run in training mode. *Not* used in this
        architecture.
      context: Unused.

    Returns:
      Mapped inputs of shape [batch, <spatial>, skip_features]
    """
        assert context is None

        # Sinusoidal features + MLP for time step embedding.
        # Note: this differs from the DiffWave embedding in several ways:
        # * Time embeddings have different dimensionality: 128-512-512
        #   vs 256-1024-1024.
        # * First convolution has kernel size 3 instead of 1.
        h, t_embed = input_embedding.InputProcessingAudio(
            num_classes=self.num_classes,
            num_channels=self.features,
            max_time=self.max_time,
            is_causal=self.is_causal)(x, t, mask, train)
        del x, t, mask

        h = nn.relu(h)
        h = ResGroup(num_blocks=self.num_blocks,
                     features=self.features,
                     skip_features=self.skip_features,
                     kernel_size=self.kernel_size,
                     dilation_cycle=self.dilation_cycle,
                     kernel_init=self.kernel_init,
                     is_causal=self.is_causal,
                     name='res_group')(h, t_embed)

        # Final convolution.
        h = nn.Conv(features=self.skip_features or self.features,
                    kernel_size=(1, ) * len(self.kernel_size),
                    kernel_init=self.kernel_init,
                    name='flower_conv')(h)
        h = nn.relu(h)
        if self.output_features:
            h = nn.Conv(features=self.output_features,
                        kernel_size=(1, ) * len(self.kernel_size),
                        kernel_init=nn.initializers.zeros,
                        name='class_conv')(h)

        return h
Example No. 18
    def __call__(self, x):

        if self.minimalistic:
            kernel = 3
            activation = relu
            se_ratio = None
        else:
            kernel = 5
            activation = hard_swish
            se_ratio = 0.25

        # Input processing (shared between small and large variants).
        x = x / 255
        x = nn.Conv(features=16,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    use_bias=False)(x)
        x = activation(x)

        # Main network
        get_args = _get_large_args if self.large else _get_small_args
        for args in get_args(kernel, activation, se_ratio, self.alpha):
            x = ResidualInvertedBottleneck(*args,
                                           batch_norm=self.batch_norm)(x)

        # Last stages (shared between small and large variants).
        x = nn.Conv(features=_depth(x.shape[-1] * 6),
                    kernel_size=(1, 1),
                    use_bias=False)(x)
        if self.batch_norm:
            x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
        x = activation(x)

        if self.alpha > 1.0:
            last_point_features = _depth(self.last_point_features * self.alpha)
        else:
            last_point_features = self.last_point_features
        x = nn.Conv(features=last_point_features,
                    kernel_size=(1, 1),
                    use_bias=True)(x)
        x = activation(x)

        x = global_average_pooling(x)
        x = x.reshape((x.shape[0], 1, 1, last_point_features))

        x = nn.Conv(features=self.classes, kernel_size=(1, 1))(x)
        x = flatten(x)
        # x = self.classifier_activation(x)

        return x
Example No. 19
 def __call__(self, x):
     x = nn.Conv(features=28, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(28)(x)
     x = nn.gelu(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(2, 2))(x)
     x = nn.GroupNorm(32)(x)
     x = nn.gelu(x)
     x = x.reshape((x.shape[0], -1))
     mean_x = nn.Dense(self.latent_dim, name='fc2_mean')(x)
     logvar_x = nn.Dense(self.latent_dim, name='fc2_logvar')(x)
     return mean_x, logvar_x
Example No. 20
    def __call__(self, x, train: bool = True):
        # Common arguments
        kwargs = {
            'kernel_size': (4, 4),
            'strides': (2, 2),
            'padding': 'SAME',
            'use_bias': False,
            'kernel_init': he_normal()
        }

        # x = np.reshape(x, (64, 64, 1))
        x = x[..., None]

        # Layer 1
        x = nn.Conv(features=64, **kwargs)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 2
        x = nn.Conv(features=128, **kwargs)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 3
        x = nn.Conv(features=256, **kwargs)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 4
        x = nn.Conv(features=512, **kwargs)(x)
        x = nn.BatchNorm(use_running_average=not train)(x)
        x = nn.leaky_relu(x, 0.2)

        # Layer 5
        x = nn.Conv(features=4096,
                    kernel_size=(4, 4),
                    strides=(1, 1),
                    padding='VALID',
                    use_bias=False,
                    kernel_init=he_normal())(x)
        x = nn.leaky_relu(x, 0.2)

        # Flatten
        x = x.flatten()

        # Predict latent variables
        z_mean = nn.Dense(features=self.zdim)(x)
        z_logvar = nn.Dense(features=self.zdim)(x)

        return z_mean, z_logvar
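The two heads are the usual Gaussian-posterior parameters of a VAE encoder; a hedged sketch of how they are typically consumed via the reparameterization trick (this sampling step is illustrative and not part of the snippet):

import jax
import jax.numpy as jnp

def reparameterize(rng, z_mean, z_logvar):
  """Sample z ~ N(z_mean, exp(z_logvar)) with the reparameterization trick."""
  eps = jax.random.normal(rng, z_mean.shape)
  return z_mean + jnp.exp(0.5 * z_logvar) * eps

# Example with a 32-dimensional latent (shapes are illustrative).
z = reparameterize(jax.random.PRNGKey(0), jnp.zeros((32,)), jnp.zeros((32,)))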
Example No. 21
    def __call__(self, x):
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = nn.Conv(features=16, kernel_size=(3, 3))(x)
        x = nn.relu(x)
        x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(features=256)(x)

        # return intermediate output
        x = TestDense()(x)

        x = nn.Dense(features=N_CLASSES)(x)
        return x
Example No. 22
 def setup(self):
     if (self.d == 0):
         self.latent = nn.Conv(self.outfeatures, (1, 1), strides=1)
         self.group_latent = nn.normalization.GroupNorm(self.ngroup)
     else:
         self.conv = nn.Conv(self.outfeatures, (3, 3), strides=2)
         self.group_conv = nn.normalization.GroupNorm(self.ngroup)
         self.mid = UnitUnet(self.d - 1, self.outfeatures * 2, self.ngroup)
         self.deconv = lambda x: jax.image.resize(
             x, (x.shape[0], x.shape[1] * 2, x.shape[2] * 2, x.shape[3]),
             method='bilinear')
         self.conv2 = nn.Conv(self.outfeatures, (3, 3), strides=1)
         self.group_deconv = nn.normalization.GroupNorm(self.ngroup)
         self.conv3 = nn.Conv(self.outfeatures, (3, 3), strides=1)
         self.group_deconv3 = nn.normalization.GroupNorm(self.ngroup)
Example No. 23
 def __call__(self, x):
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
     x = activation(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
     x = activation(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
     x = activation(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
     x = activation(x)
     x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1))(x)
     x = activation(x)
     x = jnp.reshape(x, (x.shape[0], -1))
     x = nn.Dense(10)(x)
     x = nn.log_softmax(x)
     return x
Example No. 24
 def __call__(self, x, with_classifier=True):
   x = nn.Conv(features=32, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = nn.Conv(features=64, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = x.reshape((x.shape[0], -1))  # flatten
   x = nn.Dense(features=256)(x)
   x = nn.relu(x)
   if not with_classifier:
     return x
   x = nn.Dense(features=10)(x)
   x = nn.log_softmax(x)
   return x
Example No. 25
    def setup(self):
        embed_dim = self.config.hidden_size
        image_size = self.config.image_size
        patch_size = self.config.patch_size

        self.class_embedding = self.param(
            "class_embedding", jax.nn.initializers.normal(stddev=0.02),
            (embed_dim, ))

        self.patch_embedding = nn.Conv(
            embed_dim,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding="VALID",
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(),
        )

        self.num_patches = (image_size // patch_size)**2
        num_positions = self.num_patches + 1
        self.position_embedding = nn.Embed(
            num_positions,
            embed_dim,
            embedding_init=jax.nn.initializers.normal())
        self.position_ids = jnp.expand_dims(jnp.arange(0,
                                                       num_positions,
                                                       dtype="i4"),
                                            axis=0)
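The `setup` above only creates the layers; a hedged sketch of the forward pass such a patch embedding usually performs, with a small stand-in module (attribute values and the reshape are illustrative, not the repository's `__call__`):

import jax
import jax.numpy as jnp
from flax import linen as nn

class TinyPatchEmbed(nn.Module):
  """Illustrative stand-in for a ViT/CLIP-style patch embedding."""
  embed_dim: int = 64
  patch_size: int = 16

  @nn.compact
  def __call__(self, pixel_values):            # [B, H, W, C]
    patches = nn.Conv(self.embed_dim,
                      kernel_size=(self.patch_size, self.patch_size),
                      strides=(self.patch_size, self.patch_size),
                      padding='VALID', use_bias=False)(pixel_values)
    b, h, w, d = patches.shape
    patches = patches.reshape((b, h * w, d))   # [B, num_patches, embed_dim]
    cls = self.param('class_embedding',
                     jax.nn.initializers.normal(stddev=0.02), (d,))
    cls = jnp.tile(cls[None, None, :], (b, 1, 1))
    embeddings = jnp.concatenate([cls, patches], axis=1)  # [B, num_patches + 1, D]
    positions = nn.Embed(h * w + 1, d)(jnp.arange(h * w + 1)[None, :])
    return embeddings + positions

x = jnp.zeros((2, 224, 224, 3))
model = TinyPatchEmbed()
variables = model.init(jax.random.PRNGKey(0), x)
out = model.apply(variables, x)                # [2, 197, 64]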
Example No. 26
 def __call__(self, x, with_classifier=True):
   x = nn.Conv(features=32, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   x = nn.Conv(features=64, kernel_size=(3, 3))(x)
   x = nn.relu(x)
   x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
   # TODO: replace np.prod(x.shape[1:]) with -1 once we fix shape_polymorphism
   x = x.reshape((x.shape[0], np.prod(x.shape[1:])))  # flatten
   x = nn.Dense(features=256)(x)
   x = nn.relu(x)
   if not with_classifier:
     return x
   x = nn.Dense(features=10)(x)
   x = nn.log_softmax(x)
   return x
Example No. 27
    def setup(self):
        self.backbone = Sequential(ResNet50(n_classes=1).layers[0:18])
        self.feature_conv = nn.Conv(
            features=self.config.mlp_dim,
            kernel_size=(1, 1))  # 1D conv to convert resnet outputs down
        self.transformer = Transformer(self.config)
        self.linear_class = nn.Dense(
            self.config.output_dim + 1,
            dtype=self.config.dtype,
            kernel_init=self.config.kernel_init,
            bias_init=self.config.bias_init)  # +1 is for the empty class
        self.linear_bbox = nn.Dense(4,
                                    dtype=self.config.dtype,
                                    kernel_init=self.config.kernel_init,
                                    bias_init=self.config.bias_init)

        self.query_pos = self.param('queries',
                                    nn.initializers.uniform(scale=1),
                                    (100, self.config.mlp_dim),
                                    self.config.dtype)
        self.row_embed = self.param('row_embed',
                                    nn.initializers.uniform(scale=1),
                                    (50, self.config.mlp_dim // 2),
                                    self.config.dtype)
        self.col_embed = self.param('col_embed',
                                    nn.initializers.uniform(scale=1),
                                    (50, self.config.mlp_dim // 2),
                                    self.config.dtype)
Example No. 28
 def __call__(self, inputs, train: bool = False):
   _d = max(1, self.dilation)
   # Pad height/width so the dilated 3x3 'VALID' conv below preserves the spatial size.
   x = jnp.pad(inputs, [(0, 0), (_d, _d), (_d, _d), (0, 0)], mode='constant', constant_values=0)
   x = nn.Conv(self.channels, (3, 3), padding='VALID', kernel_dilation=(_d, _d), use_bias=False, name='conv1')(x)
   x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
   x = nn.relu(x)
   return x
Example No. 29
 def __call__(self, x, train):
     maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
     iterator = zip(self.num_filters, self.kernel_sizes,
                    self.kernel_paddings, self.window_sizes,
                    self.window_paddings, self.strides)
     for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator:
         x = nn.Conv(num_filters, (kernel_size, kernel_size), (1, 1),
                     padding=kernel_padding,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(x)
         x = model_utils.ACTIVATIONS[self.activation_fn](x)
         x = maybe_normalize()(x)
         x = nn.max_pool(x,
                         window_shape=(window_size, window_size),
                         strides=(stride, stride),
                         padding=window_padding)
     x = jnp.reshape(x, (x.shape[0], -1))
     for num_units in self.num_dense_units:
         x = nn.Dense(num_units,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x)
         x = model_utils.ACTIVATIONS[self.activation_fn](x)
         x = maybe_normalize()(x)
     x = nn.Dense(self.num_outputs,
                  kernel_init=self.kernel_init,
                  bias_init=self.bias_init)(x)
     return x
Example No. 30
 def _upsample(self, x, name):
     B, H, W, C = x.shape  # pylint: disable=invalid-name
     x = nearest_neighbor_upsample(x)
     x = nn.Conv(features=C, kernel_size=(3, 3), strides=(1, 1),
                 name=name)(x)
     assert x.shape == (B, H * 2, W * 2, C)
     return x
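`nearest_neighbor_upsample` is not defined in the snippet; a hedged sketch of the 2x nearest-neighbour helper it presumably resembles (an assumption, not the repository's code):

import jax.numpy as jnp

def nearest_neighbor_upsample(x):
  """Upsample [B, H, W, C] to [B, 2H, 2W, C] by repeating each pixel."""
  x = jnp.repeat(x, 2, axis=1)
  return jnp.repeat(x, 2, axis=2)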