def __call__(self, inputs):
        """Apply an inverted-residual bottleneck block.

        Stages: 1x1 expansion conv -> (BatchNorm) -> activation ->
        depthwise conv -> (BatchNorm) -> activation -> optional
        squeeze-and-excite -> 1x1 projection conv -> (BatchNorm). The input
        is added back when ``self.stride == 1`` and the channel count is
        unchanged.

        Args:
          inputs: feature map whose last axis is the channel axis.

        Returns:
          The block output with ``self.filters`` channels.
        """
        in_channels = inputs.shape[-1]

        def maybe_norm(h):
            # BatchNorm settings follow the Keras reference; the stage is
            # skipped entirely when the block is built without batch_norm.
            if not self.batch_norm:
                return h
            return nn.BatchNorm(epsilon=1e-3, momentum=0.999)(h)

        # Expansion (block_id controls this stage in the Keras implementation).
        h = nn.Conv(_depth(in_channels * self.expansion),
                    kernel_size=(1, 1),
                    use_bias=False)(inputs)
        h = self.activation(maybe_norm(h))

        # Depthwise stage: stride-2 blocks are padded explicitly first; any
        # stride other than 1 then uses "valid" padding.
        if self.stride == 2:
            h = zero_pad_2d(correct_pad(h, self.kernel_size))(h)
        kernel = (self.kernel_size, self.kernel_size)
        strides = (self.stride, self.stride)
        padding = "same" if self.stride == 1 else "valid"
        h = DepthwiseConv2D(kernel_size=kernel,
                            strides=strides,
                            padding=padding,
                            use_bias=False)(h)
        h = self.activation(maybe_norm(h))

        if self.se_ratio:
            h = SEBlock(self.se_ratio)(h)

        # Linear projection: no activation after the final normalization.
        h = maybe_norm(nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(h))

        keeps_shape = self.stride == 1 and in_channels == self.filters
        return inputs + h if keeps_shape else h
Пример #2
0
    def __call__(self, z, train: bool = True):
        """Decode one latent vector into a single-channel image.

        Args:
          z: latent vector with ``self.zdim`` elements. It is reshaped to a
            1x1 map with ``zdim`` channels, so no batch axis is expected.
          train: if True, BatchNorm normalizes with batch statistics (and
            updates its running averages); if False it uses the stored ones.

        Returns:
          The softplus-activated output of the last transposed conv with all
          size-1 axes squeezed, rotated by 180 degrees (``k=2``) to match the
          TF reference output.
        """
        # Shared settings for the stride-2 (upsampling) transposed convs.
        conv_kwargs = {
            'kernel_size': (4, 4),
            'strides': (2, 2),
            'padding': 'SAME',
            'use_bias': False,
            'kernel_init': he_normal()
        }
        norm_kwargs = {
            'use_running_average': not train,
            'momentum': 0.99,
            'epsilon': 0.001,
            'use_scale': True,
            'use_bias': True
        }

        # Use jnp rather than np so the reshape is a JAX op on traced arrays
        # instead of relying on NumPy delegating to the array's own method.
        z = jnp.reshape(z, (1, 1, self.zdim))

        # Layer 1: 1x1 -> 4x4 via a 4x4 VALID, stride-1 transposed conv.
        z = nn.ConvTranspose(features=512,
                             kernel_size=(4, 4),
                             strides=(1, 1),
                             padding='VALID',
                             use_bias=False,
                             kernel_init=he_normal())(z)
        z = nn.BatchNorm(**norm_kwargs)(z)
        z = nn.leaky_relu(z, 0.2)

        # Layers 2-4: stride-2 SAME transposed convs, each doubling the
        # spatial size while narrowing the channels. Module creation order
        # (and hence auto-generated names) is the same as unrolled code.
        for features in (256, 128, 64):
            z = nn.ConvTranspose(features=features, **conv_kwargs)(z)
            z = nn.BatchNorm(**norm_kwargs)(z)
            z = nn.leaky_relu(z, 0.2)

        # Layer 5: single output channel, Xavier-initialised, no BatchNorm.
        z = nn.ConvTranspose(features=1,
                             kernel_size=(4, 4),
                             strides=(2, 2),
                             padding='SAME',
                             use_bias=False,
                             kernel_init=nn.initializers.xavier_normal())(z)
        x = nn.softplus(z)

        return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate to match TF output
Пример #3
0
    def __call__(self, x, train: bool = True):
        """Encode an image into the parameters of a diagonal Gaussian.

        Args:
          x: image without a channel axis; a size-1 channel axis is appended.
          train: if True, BatchNorm uses batch statistics; otherwise it uses
            the running averages.

        Returns:
          ``(z_mean, z_logvar)``, two separate Dense projections with
          ``self.zdim`` features of the flattened conv features.
        """
        down = dict(kernel_size=(4, 4),
                    strides=(2, 2),
                    padding='SAME',
                    use_bias=False,
                    kernel_init=he_normal())

        h = x[..., None]

        # Stem conv: no normalization on the first layer.
        h = nn.leaky_relu(nn.Conv(features=64, **down)(h), 0.2)

        # Three stride-2 conv -> BatchNorm -> leaky ReLU stages.
        for width in (128, 256, 512):
            h = nn.Conv(features=width, **down)(h)
            h = nn.BatchNorm(use_running_average=not train)(h)
            h = nn.leaky_relu(h, 0.2)

        # Final 4x4 VALID conv, then flatten every axis into one vector.
        h = nn.Conv(features=4096,
                    kernel_size=(4, 4),
                    strides=(1, 1),
                    padding='VALID',
                    use_bias=False,
                    kernel_init=he_normal())(h)
        h = nn.leaky_relu(h, 0.2).flatten()

        z_mean = nn.Dense(features=self.zdim)(h)
        z_logvar = nn.Dense(features=self.zdim)(h)
        return z_mean, z_logvar
Пример #4
0
    def setup(self):
        """Register the ``theta``/``phi`` projections and an optional BatchNorm.

        Both projections are Dense layers with ``self.out_feat`` outputs.
        ``self.bn`` is only registered when ``self.batch_norm`` is truthy.
        """
        self.theta = nn.Dense(self.out_feat)
        self.phi = nn.Dense(self.out_feat)

        if not self.batch_norm:
            return
        self.bn = nn.BatchNorm()
Пример #5
0
    def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
        """Classify ``x`` into 10 classes with a three-conv network.

        Inputs are cast to float32 and scaled by 1/255. Dropout and
        BatchNorm switch to inference behaviour when ``training`` is False.
        Returns un-normalized scores from a Dense(10) layer.
        """
        deterministic = not training
        h = x.astype(jnp.float32) / 255.0

        # Stage 1: stride-2 conv, dropout, ReLU.
        h = linen.Conv(32, [3, 3], strides=[2, 2])(h)
        h = jax.nn.relu(linen.Dropout(0.05, deterministic=deterministic)(h))

        # Stage 2: stride-2 conv, BatchNorm, dropout, ReLU.
        h = linen.Conv(64, [3, 3], strides=[2, 2])(h)
        h = linen.BatchNorm(use_running_average=deterministic)(h)
        h = jax.nn.relu(linen.Dropout(0.1, deterministic=deterministic)(h))

        # Stage 3: stride-2 conv.
        # NOTE(review): no nonlinearity follows this conv, so it composes
        # linearly with the pooling and Dense head -- confirm intended.
        h = linen.Conv(128, [3, 3], strides=[2, 2])(h)

        # Average over the two spatial axes, then classify.
        pooled = h.mean(axis=(1, 2))
        return linen.Dense(10)(pooled)
Пример #6
0
 def __call__(self, inputs, train: bool = False):
   """Segmentation head: ASPP, 3x3 conv-BN-ReLU, then 1x1 classifier conv.

   Returns per-pixel scores with ``self.num_classes`` channels.
   """
   features = ASPP([12, 24, 36], name='ASPP')(inputs)
   features = nn.Conv(256, (3, 3), padding='SAME', use_bias=False, name="conv1")(features)
   features = nn.relu(nn.BatchNorm(use_running_average=not train, name="bn1")(features))
   logits = nn.Conv(self.num_classes, (1, 1), padding='VALID', use_bias=True, name="conv2")(features)
   return logits
Пример #7
0
    def __call__(self, inputs, is_training: bool):
        """Convolutional patch embedding.

        Runs conv -> BatchNorm -> max-pool, regroups the pooled map into
        non-overlapping ``patch_shape`` patches (each flattened together
        with its channels), and projects every patch to ``self.embed_dim``.
        """
        ph, pw = self.patch_shape[0], self.patch_shape[1]
        k, s = self.conv_kernel_size, self.conv_stride

        # NOTE(review): each spatial side is padded by the patch size, not
        # by a kernel-derived amount -- confirm this is intended.
        x = nn.Conv(features=self.num_ch,
                    use_bias=self.use_bias,
                    kernel_size=(k, k),
                    strides=(s, s),
                    padding=[(ph, ph), (pw, pw)])(inputs)
        x = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(x)

        window, stride = self.pool_window_size, self.pool_stride
        x = nn.max_pool(inputs=x,
                        window_shape=(window, window),
                        strides=(stride, stride))

        # (b, H, W, c) -> (b, num_patches, ph * pw * c)
        x = rearrange(x, 'b (h ph) (w pw) c -> b (h w) (ph pw c)', ph=ph, pw=pw)

        projection = nn.Dense(features=self.embed_dim,
                              use_bias=self.use_bias,
                              dtype=self.dtype,
                              precision=self.precision,
                              kernel_init=self.kernel_init,
                              bias_init=self.bias_init)
        return projection(x)
Пример #8
0
 def __call__(self, inputs, train: bool = False):
   """Dilated 3x3 conv -> BatchNorm -> ReLU that preserves spatial size.

   The input is zero-padded by ``d = max(1, self.dilation)`` on both sides of
   axes 1 and 2. A VALID 3x3 conv with dilation ``d`` has an effective extent
   of ``2d + 1``, so the output keeps the input's height and width.

   Args:
     inputs: 4-D feature map; axes 1 and 2 are padded (NHWC presumably).
     train: if True, BatchNorm uses batch statistics.

   Returns:
     Activated features with ``self.channels`` channels.
   """
   _d = max(1, self.dilation)
   # Pass the fill value by keyword: jnp.pad (like np.pad) only accepts
   # `array, pad_width, mode` positionally, so a 4th positional argument
   # raises TypeError.
   x = jnp.pad(inputs, [(0, 0), (_d, _d), (_d, _d), (0, 0)],
               mode='constant', constant_values=0)
   x = nn.Conv(self.channels, (3, 3), padding='VALID', kernel_dilation=(_d, _d), use_bias=False, name='conv1')(x)
   x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
   x = nn.relu(x)
   return x
Пример #9
0
    def __call__(
        self,
        inputs,
    ):
        """Apply the quantization-aware ResNet.

        The residual stack is built from ``self.hparams.residual_blocks``.
        Every projection block except the first halves the resolution and
        doubles the width.

        Args:
          inputs: input images; the stem pads axes 1 and 2, so NHWC is
            presumably expected -- TODO confirm.

        Returns:
          ``nn.log_softmax`` of the head's logits after casting them to
          ``self.dtype``.
        """
        hparams = self.hparams
        dtype = self.dtype
        base_width = self.num_filters

        # Stem: 7x7/2 quantized conv, BatchNorm, ReLU, 3x3/2 max-pool.
        x = aqt_flax_layers.ConvAqt(
            features=base_width,
            kernel_size=(7, 7),
            strides=(2, 2),
            padding=[(3, 3), (3, 3)],
            use_bias=False,
            dtype=dtype,
            name='init_conv',
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.conv_init,
        )(inputs)
        x = nn.BatchNorm(use_running_average=not self.train,
                         momentum=0.9,
                         epsilon=1e-5,
                         dtype=dtype,
                         name='init_bn')(x)
        x = nn.max_pool(nn.relu(x), (3, 3), strides=(2, 2), padding='SAME')

        width_scale = hparams.filter_multiplier
        for index, block_hparams in enumerate(hparams.residual_blocks):
            downsample = index > 0 and block_hparams.conv_proj is not None
            if downsample:
                width_scale *= 2
            x = ResidualBlock(filters=int(base_width * width_scale),
                              hparams=block_hparams,
                              quant_context=self.quant_context,
                              strides=(2, 2) if downsample else (1, 1),
                              train=self.train,
                              dtype=dtype)(x)

        # Global average over the two spatial axes, then the quantized head.
        pooled = jnp.mean(x, axis=(1, 2))
        logits = aqt_flax_layers.DenseAqt(
            features=self.num_classes,
            dtype=dtype,
            train=self.train,
            quant_context=self.quant_context,
            paxis_name='batch',
            hparams=hparams.dense_layer,
        )(pooled, padding_mask=None)

        return nn.log_softmax(jnp.asarray(logits, dtype))
Пример #10
0
 def __call__(self, x):
     """Apply one shared BatchNorm twice and return both results.

     A single ``norm`` module (training-mode statistics, reduced across the
     "batch" axis name) is applied to ``x`` and then to its own output, so
     both applications share the same variables.
     """
     shared_norm = nn.BatchNorm(
         name="norm",
         use_running_average=False,
         axis_name="batch",
     )
     first = shared_norm(x)
     second = shared_norm(first)
     return first, second
Пример #11
0
 def __call__(self, x):
   """Bias-free 1x1 conv to ``self.filters`` channels, then BatchNorm (float32)."""
   projected = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
   normalize = nn.BatchNorm(use_running_average=not self.train,
                            momentum=0.9,
                            epsilon=1e-5,
                            dtype=jnp.float32)
   return normalize(projected)
Пример #12
0
  def __call__(self, inputs, train: bool = False):
    """FCN-style head.

    Applies a 3x3 conv reducing channels to a quarter of the input's, then
    BN, ReLU and dropout (0.1, only when ``train``), then a 1x1 conv to
    ``self.channels``.
    """
    reduced = np.shape(inputs)[-1] // 4
    h = nn.Conv(reduced, (3, 3), padding='SAME', use_bias=False, name="conv1")(inputs)
    h = nn.relu(nn.BatchNorm(use_running_average=not train, name="bn1")(h))
    h = nn.Dropout(0.1)(h, deterministic=not train)
    return nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=True, name="conv2")(h)
Пример #13
0
  def test_batch_norm(self):
    """Test virtual BN recovers BN when the virtual size equals batch size.

    Both modules are stepped twice in training mode on the same input. The
    second-step outputs and the running mean/variance must agree.
    """
    rng = jax.random.PRNGKey(0)
    batch_size, feature_size = 10, 7
    input_shape = (batch_size, 3, 3, feature_size)
    half_shape = (batch_size // 2, 3, 3, feature_size)
    # Half the batch is all 2s and half all 9s, giving non-trivial stats.
    x = jnp.concatenate((2.0 * jnp.ones(half_shape), 9.0 * jnp.ones(half_shape)))

    bn_flax_module = nn.BatchNorm(momentum=0.9)
    bn_params, bn_state = _init(bn_flax_module, rng, input_shape)

    vbn_flax_module = normalization.VirtualBatchNorm(
        momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC')
    vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape)

    def step(module, params, state):
      """Run one training-mode pass; return (output, updated batch_stats)."""
      y, updated = module.apply(
          {'params': params, 'batch_stats': state},
          x,
          mutable=['batch_stats'],
          use_running_average=False)
      return y, updated['batch_stats']

    _, bn_state = step(bn_flax_module, bn_params, bn_state)
    bn_y, bn_state = step(bn_flax_module, bn_params, bn_state)

    _, vbn_state = step(vbn_flax_module, vbn_params, vbn_state)
    vbn_y, vbn_state = step(vbn_flax_module, vbn_params, vbn_state)

    # Forward passes agree.
    np.testing.assert_allclose(bn_y, vbn_y, atol=1e-4)

    # Running averages agree (virtual BN stores them with extra size-1 axes).
    np.testing.assert_allclose(
        bn_state['mean'],
        np.squeeze(vbn_state['batch_norm_running_mean']),
        atol=1e-4)
    np.testing.assert_allclose(
        bn_state['var'],
        np.squeeze(vbn_state['batch_norm_running_var']),
        atol=1e-4)
Пример #14
0
  def __call__(self, inputs, train: bool = False):
    """Atrous Spatial Pyramid Pooling (ASPP).

    Runs parallel branches on ``inputs``: a 1x1 conv + BN + ReLU, one
    ``ASPPConv`` per rate in ``self.atrous_rates``, and an ``ASPPPooling``
    branch. It concatenates them on the last axis and fuses the result with
    a 1x1 conv, BN, ReLU and dropout (rate 0.5, active only when ``train``).
    """
    pointwise = nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=False, name="conv1")(inputs)
    pointwise = nn.BatchNorm(use_running_average=not train, name="bn1")(pointwise)
    branches = [nn.relu(pointwise)]
    branches += [
        ASPPConv(self.channels, rate, name=f'ASPPConv{i+1}')(inputs)
        for i, rate in enumerate(self.atrous_rates)
    ]
    branches.append(ASPPPooling(self.channels, name='ASPPPooling')(inputs))
    # e.g. 5 branches x 256 channels = 1280 with three atrous rates.
    fused = jnp.concatenate(branches, -1)

    fused = nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=False, name="conv2")(fused)
    fused = nn.BatchNorm(use_running_average=not train, name="bn2")(fused)
    return nn.Dropout(0.5)(nn.relu(fused), deterministic=not train)
Пример #15
0
 def __call__(self, inputs, is_training):
     """Dropout -> bias-free Dense(vocab_size) -> BatchNorm without scale/offset.

     Dropout and BatchNorm switch to inference behaviour when
     ``is_training`` is False.
     """
     deterministic = not is_training
     h = nn.Dropout(self.dropout_rate, deterministic=deterministic)(inputs)
     logits = nn.Dense(self.vocab_size, use_bias=False)(h)
     normalizer = nn.BatchNorm(
         use_bias=False,
         use_scale=False,
         momentum=0.9,
         use_running_average=deterministic,
     )
     return normalizer(logits)
Пример #16
0
  def __call__(self, inputs, train: bool = False):
    """Image-pooling branch of ASPP.

    Averages each channel over the full spatial extent, applies a 1x1 conv,
    BatchNorm and ReLU, then bilinearly resizes the 1x1 result back to the
    input's spatial size.

    Args:
      inputs: 4-D feature map laid out (batch, height, width, channels), as
        implied by the resize target shape built below.
      train: if True, BatchNorm uses batch statistics.

    Returns:
      Array of shape (batch, height, width, self.channels).
    """
    batch_size = np.shape(inputs)[0]
    in_shape = np.shape(inputs)[1:-1]
    # Pool window equal to the whole spatial extent -> one value per channel.
    x = nn.avg_pool(inputs, in_shape)
    x = nn.Conv(self.channels, (1, 1), padding='SAME', use_bias=False, name="conv1")(x)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)

    # jax.image.resize resizes every axis whose size differs, the batch axis
    # included. So the real batch size must be kept: a hard-coded 1 would
    # interpolate across (blend) the samples of any batch larger than one.
    out_shape = (batch_size, in_shape[0], in_shape[1], self.channels)
    x = jax.image.resize(x, shape=out_shape, method='bilinear')

    return x
Пример #17
0
 def setup(self):
     """Build the conv -> norm -> act -> conv -> norm -> act stack.

     Normalization is ``GroupNorm(self.num_groups)`` when ``self.group_norm``
     is set, otherwise BatchNorm with ``use_running_average=self.test``. The
     activation is softplus when ``self.activation == 'softplus'``, else
     ReLU. Both convs are 3x3 and bias-free.
     """
     act = nn.softplus if self.activation == 'softplus' else nn.relu

     def make_norm():
         # Called once per norm slot so each slot gets its own module, created
         # in the same order as the surrounding convs.
         if self.group_norm:
             return nn.GroupNorm(self.num_groups)
         return nn.BatchNorm(use_running_average=self.test)

     self.double_conv = Sequential([
         nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
         make_norm(),
         act,
         nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
         make_norm(),
         act,
     ])
Пример #18
0
 def __call__(self, x, train=False):
     """VGG-style feature stack configured by ``self.cfg``.

     Each ``'M'`` entry applies 2x2 max pooling with stride 2. Any other
     entry is the output width of a 3x3 SAME conv, followed by BatchNorm
     (when ``self.batch_norm``) and ReLU.
     """
     for spec in self.cfg:
         if spec == 'M':
             x = nn.max_pool(x, (2, 2), (2, 2))
             continue
         x = nn.Conv(spec, (3, 3), padding='SAME', dtype=self.dtype)(x)
         if self.batch_norm:
             # NOTE(review): Flax `momentum` is the running-average decay, so
             # 0.1 keeps only 10% of the old statistics per step. If this was
             # copied from PyTorch (where 0.1 means the opposite), the Flax
             # equivalent is 0.9 -- confirm.
             x = nn.BatchNorm(use_running_average=not train,
                              momentum=0.1,
                              dtype=self.dtype)(x)
         x = nn.relu(x)
     return x
Пример #19
0
 def __call__(self, x):
     """Dense(10), then optional always-on dropout and training-mode BatchNorm.

     ``dropout`` and ``batchnorm`` are free variables resolved from an
     enclosing scope that is not visible here.
     """
     h = nn.Dense(10)(x)
     if dropout:
         h = nn.Dropout(0.5, deterministic=False)(h)
     if not batchnorm:
         return h
     return nn.BatchNorm(
         use_bias=True,
         use_scale=True,
         momentum=0.999,
         use_running_average=False,
     )(h)
Пример #20
0
    def __call__(self, inputs, is_training):
        """Map inputs to strictly positive per-topic concentrations.

        Two softplus Dense(hidden) layers and dropout feed a
        Dense(num_topics) layer. Its output is batch-normalized without
        scale/offset and exponentiated.
        """
        h = inputs
        for _ in range(2):
            h = nn.softplus(nn.Dense(self.hidden)(h))
        h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
        h = nn.Dense(self.num_topics)(h)

        normalize = nn.BatchNorm(
            use_bias=False,
            use_scale=False,
            momentum=0.9,
            use_running_average=not is_training,
        )
        log_concentration = normalize(h)
        return jnp.exp(log_concentration)
Пример #21
0
 def __call__(self, x, train):
     """BatchNorm ('init_bn') -> mean over axes 1 and 2 -> Dense classifier.

     BatchNorm statistics can be reduced across ``self.axis_name`` (optionally
     restricted to ``self.axis_index_groups``).
     """
     bn = nn.BatchNorm(use_running_average=not train,
                       momentum=0.9,
                       epsilon=1e-5,
                       name='init_bn',
                       axis_name=self.axis_name,
                       axis_index_groups=self.axis_index_groups,
                       dtype=self.dtype)
     pooled = jnp.mean(bn(x), axis=(1, 2))
     head = nn.Dense(self.num_classes,
                     kernel_init=nn.initializers.normal(),
                     dtype=self.dtype)
     return head(pooled)
Пример #22
0
 def __call__(self, x):
     """Bias-free conv ('conv') -> BatchNorm ('bn', epsilon 1e-3) -> ReLU."""
     y = nn.Conv(self.out_channels,
                 kernel_size=self.kernel_size,
                 strides=self.strides,
                 padding=self.padding,
                 use_bias=False,
                 name='conv',
                 dtype=self.dtype)(x)
     y = nn.BatchNorm(use_running_average=not self.train,
                      epsilon=0.001,
                      name='bn',
                      dtype=self.dtype)(y)
     return nn.relu(y)
    def __call__(self, x):
        """MobileNetV3 forward pass.

        Inputs are divided by 255 and run through:
          * a stride-2 3x3 stem conv;
          * the inverted-residual stages from ``_get_large_args`` or
            ``_get_small_args``;
          * a 1x1 expansion conv;
          * a 1x1 "last point" conv;
          * global average pooling;
          * a 1x1 classifier conv.
        The flattened logits are returned; no final activation is applied.
        """
        # Minimalistic variant: 3x3 kernels, ReLU and no squeeze-and-excite.
        if self.minimalistic:
            kernel, activation, se_ratio = 3, relu, None
        else:
            kernel, activation, se_ratio = 5, hard_swish, 0.25

        # Stem (shared by the small and large variants).
        h = x / 255
        h = nn.Conv(features=16,
                    kernel_size=(3, 3),
                    strides=(2, 2),
                    use_bias=False)(h)
        # NOTE(review): the Keras reference normalizes the stem conv output
        # before the activation; no BatchNorm is applied here. Adding one
        # would renumber the auto-named BatchNorm below -- confirm first.
        h = activation(h)

        # Inverted-residual stages.
        stage_args = (_get_large_args if self.large else _get_small_args)(
            kernel, activation, se_ratio, self.alpha)
        for args in stage_args:
            h = ResidualInvertedBottleneck(*args,
                                           batch_norm=self.batch_norm)(h)

        # Head (shared by the small and large variants).
        h = nn.Conv(features=_depth(h.shape[-1] * 6),
                    kernel_size=(1, 1),
                    use_bias=False)(h)
        if self.batch_norm:
            h = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(h)
        h = activation(h)

        last_point_features = (_depth(self.last_point_features * self.alpha)
                               if self.alpha > 1.0
                               else self.last_point_features)
        # NOTE(review): the Keras reference pools before this 1x1 conv; here
        # it runs at full resolution and pooling follows -- confirm intended.
        h = nn.Conv(features=last_point_features,
                    kernel_size=(1, 1),
                    use_bias=True)(h)
        h = activation(h)

        h = global_average_pooling(h)
        h = h.reshape((h.shape[0], 1, 1, last_point_features))

        h = nn.Conv(features=self.classes, kernel_size=(1, 1))(h)
        return flatten(h)
Пример #24
0
 def __call__(self, x):
     """Two residual stages built from explicitly named bias-free convs.

     Stage 1 adds a 1x1 main path to a BatchNorm'd 1x1 projection shortcut.
     Stage 2 adds a 3x3 main path to a 1x1 projection of the shortcut. Before
     that projection, the shortcut gets the value of a zero-initialised
     ("zzz_grad_stats", "b2_conv_proj_dummy") variable added to it.
     """
     features = self.num_filters * 4

     def conv(inp, name, **extra):
         # Every conv here shares width, bias-free-ness and dtype.
         return nn.Conv(features=features, name=name, use_bias=False,
                        dtype=self.dtype, **extra)(inp)

     # Stage 1.
     main = conv(x, "b1_Conv_2", kernel_size=(1, 1))
     shortcut = conv(x, "b1_conv_proj", kernel_size=(1, 1), strides=(1, 1))
     shortcut = nn.BatchNorm(
         name="b1_norm_proj",
         use_running_average=False,
         momentum=0.9,
         epsilon=1e-5,
         dtype=self.dtype,
     )(shortcut)
     x = shortcut + main

     # Stage 2.
     main = conv(x, "b2_Conv_1", kernel_size=(3, 3), strides=(1, 1))
     tracker = self.variable("zzz_grad_stats", "b2_conv_proj_dummy",
                             jnp.zeros, (x.shape))
     shortcut = conv(x + tracker.value, "b2_conv_proj",
                     kernel_size=(1, 1), strides=(1, 1))
     return shortcut + main
Пример #25
0
def batchnorm2d(
        eps=1e-3,
        momentum=0.99,
        affine=True,
        training=True,
        name: Optional[str] = None,
        bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros,
        weight_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones,
):
    """Build a Flax ``nn.BatchNorm`` from PyTorch-style argument names.

    Args:
      eps: value added to the variance for stability (Flax ``epsilon``).
      momentum: forwarded unchanged to Flax ``momentum``.
      affine: when True, learn both a scale and an offset; when False,
        use neither.
      training: when True, normalize with batch statistics; otherwise use
        the running averages.
      name: optional module name.
      bias_init: initializer for the learned offset.
      weight_init: initializer for the learned scale.

    Returns:
      An unbound ``nn.BatchNorm`` module.
    """
    config = dict(
        use_running_average=not training,
        momentum=momentum,
        epsilon=eps,
        use_bias=affine,
        use_scale=affine,
        name=name,
        bias_init=bias_init,
        scale_init=weight_init,
    )
    return nn.BatchNorm(**config)
    def __call__(self, inputs, is_training: bool):
        """Depthwise-separable conv.

        Applies a depthwise KxK conv (one group per input channel), then
        BatchNorm, then a pointwise 1x1 conv to ``self.out_ch`` channels.
        """
        channels = inputs.shape[-1]
        conv = partial(nn.Conv, dtype=self.dtype)
        k, s = self.kernel_size, self.strides

        # feature_group_count == channels makes this conv depthwise.
        h = conv(features=channels,
                 kernel_size=(k, k),
                 strides=(s, s),
                 padding='SAME',
                 feature_group_count=channels,
                 use_bias=False)(inputs)
        h = nn.BatchNorm(use_running_average=not is_training,
                         momentum=self.bn_momentum,
                         epsilon=self.bn_epsilon,
                         dtype=self.dtype)(h)

        pointwise = conv(features=self.out_ch,
                         kernel_size=(1, 1),
                         use_bias=self.use_bias)
        return pointwise(h)
Пример #27
0
  def test_batch_norm(self):
    """BatchNorm normalizes at init and updates running stats with momentum 0.9."""
    data_key, init_key = random.split(random.PRNGKey(0))
    x = random.normal(data_key, (4, 3, 2))
    model_cls = nn.BatchNorm(momentum=0.9)
    y, initial_params = model_cls.init_with_output(init_key, x)

    # Output is standardized per feature over the two leading axes.
    reduce_axes = (0, 1)
    np.testing.assert_allclose(y.mean(reduce_axes), np.array([0., 0.]), atol=1e-4)
    np.testing.assert_allclose(y.var(reduce_axes), np.array([1., 1.]), rtol=1e-4)

    _, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats'])

    # Expected running stats: 0.9 * (initial mean 0 / var 1) + 0.1 * batch stat.
    ema = vars_out['batch_stats']
    np.testing.assert_allclose(
        ema['mean'], 0.1 * x.mean(reduce_axes, keepdims=False), atol=1e-4)
    np.testing.assert_allclose(
        ema['var'], 0.9 + 0.1 * x.var(reduce_axes, keepdims=False), rtol=1e-4)
Пример #28
0
    def __call__(self, inputs, train):
        """Applies the network to inputs.

        Args:
          inputs: (batch_size, resolution, resolution, n_spins, n_channels) array.
          train: whether to run in training or inference mode.

        Returns:
          A (batch_size, num_classes) float32 array with per-class scores (logits).

        Raises:
          ValueError: If resolutions cannot be enforced with 2x2 pooling.
        """
        # Fold the spin axis into the channel axis.
        features = inputs.reshape((*inputs.shape[:3], -1))

        for layer in range(len(self.resolutions) - 1):
            res_in = self.resolutions[layer]
            res_out = self.resolutions[layer + 1]
            width = self.widths[layer + 1]

            halved = res_out == res_in // 2
            if not halved and res_out != res_in:
                raise ValueError(
                    'Consecutive resolutions must be equal or halved.')
            if halved:
                features = nn.avg_pool(features,
                                       window_shape=(2, 2),
                                       strides=(2, 2),
                                       padding='SAME')

            features = nn.Conv(features=width,
                               kernel_size=(3, 3),
                               strides=(1, 1))(features)
            features = nn.BatchNorm(use_running_average=not train,
                                    axis_name=self.axis_name)(features)
            features = nn.relu(features)

        # Average over the spatial axes, then classify.
        pooled = jnp.mean(features, axis=(1, 2))
        return nn.Dense(self.num_classes)(pooled)
Пример #29
0
 def setup(self):
   """Register Dense(3) as ``a``, a BatchNorm as ``bn`` and Dense(1) as ``b``."""
   # Assignment order is kept as-is: a, bn, b.
   self.a = nn.Dense(3)
   self.bn = nn.BatchNorm()
   self.b = nn.Dense(1)
Пример #30
0
 def __call__(self, x):
     """Build a Dense(1) -> BatchNorm side branch, then return Dense over ``x``.

     NOTE(review): the side-branch result is never used; the returned value
     depends only on ``x``, although the side branch still creates its
     modules. This may be deliberate (e.g. a fixture with unused
     submodules) -- confirm before changing.
     """
     side = nn.Dense(1)(x)
     side = nn.BatchNorm()(side)
     width = x.shape[-1]
     return nn.Dense(width)(x)