def __call__(self, inputs):
    x = inputs
    input_filters = x.shape[-1]
    # Expand (block_id controls this block in the keras implementation).
    x = nn.Conv(_depth(input_filters * self.expansion),
                kernel_size=(1, 1),
                use_bias=False)(x)
    if self.batch_norm:
        x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
    x = self.activation(x)
    if self.stride == 2:
        x = zero_pad_2d(correct_pad(x, self.kernel_size))(x)
    x = DepthwiseConv2D(kernel_size=(self.kernel_size, self.kernel_size),
                        strides=(self.stride, self.stride),
                        padding="same" if self.stride == 1 else "valid",
                        use_bias=False)(x)
    if self.batch_norm:
        x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
    x = self.activation(x)
    if self.se_ratio:
        x = SEBlock(self.se_ratio)(x)
    x = nn.Conv(self.filters, kernel_size=(1, 1), use_bias=False)(x)
    if self.batch_norm:
        x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
    if self.stride == 1 and input_filters == self.filters:
        x = inputs + x
    return x

def __call__(self, z, train: bool = True):
    # Common arguments
    conv_kwargs = {
        'kernel_size': (4, 4),
        'strides': (2, 2),
        'padding': 'SAME',
        'use_bias': False,
        'kernel_init': he_normal(),
    }
    norm_kwargs = {
        'use_running_average': not train,
        'momentum': 0.99,
        'epsilon': 0.001,
        'use_scale': True,
        'use_bias': True,
    }

    z = jnp.reshape(z, (1, 1, self.zdim))

    # Layer 1
    z = nn.ConvTranspose(features=512, kernel_size=(4, 4), strides=(1, 1),
                         padding='VALID', use_bias=False,
                         kernel_init=he_normal())(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 2
    z = nn.ConvTranspose(features=256, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 3
    z = nn.ConvTranspose(features=128, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 4
    z = nn.ConvTranspose(features=64, **conv_kwargs)(z)
    z = nn.BatchNorm(**norm_kwargs)(z)
    z = nn.leaky_relu(z, 0.2)

    # Layer 5
    z = nn.ConvTranspose(features=1, kernel_size=(4, 4), strides=(2, 2),
                         padding='SAME', use_bias=False,
                         kernel_init=nn.initializers.xavier_normal())(z)
    # x = nn.sigmoid(z)
    x = nn.softplus(z)

    return jnp.rot90(jnp.squeeze(x), k=2)  # Rotate to match TF output

def __call__(self, x, train: bool = True):
    # Common arguments
    kwargs = {
        'kernel_size': (4, 4),
        'strides': (2, 2),
        'padding': 'SAME',
        'use_bias': False,
        'kernel_init': he_normal(),
    }

    # x = np.reshape(x, (64, 64, 1))
    x = x[..., None]

    # Layer 1
    x = nn.Conv(features=64, **kwargs)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 2
    x = nn.Conv(features=128, **kwargs)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 3
    x = nn.Conv(features=256, **kwargs)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 4
    x = nn.Conv(features=512, **kwargs)(x)
    x = nn.BatchNorm(use_running_average=not train)(x)
    x = nn.leaky_relu(x, 0.2)

    # Layer 5
    x = nn.Conv(features=4096, kernel_size=(4, 4), strides=(1, 1),
                padding='VALID', use_bias=False, kernel_init=he_normal())(x)
    x = nn.leaky_relu(x, 0.2)

    # Flatten
    x = x.flatten()

    # Predict latent variables
    z_mean = nn.Dense(features=self.zdim)(x)
    z_logvar = nn.Dense(features=self.zdim)(x)

    return z_mean, z_logvar

def setup(self):
    self.theta = nn.Dense(self.out_feat)
    self.phi = nn.Dense(self.out_feat)
    if self.batch_norm:
        self.bn = nn.BatchNorm()

def __call__(self, x: jnp.ndarray, training: bool) -> jnp.ndarray:
    # Normalize the input
    x = x.astype(jnp.float32) / 255.0

    # Block 1
    x = linen.Conv(32, [3, 3], strides=[2, 2])(x)
    x = linen.Dropout(0.05, deterministic=not training)(x)
    x = jax.nn.relu(x)

    # Block 2
    x = linen.Conv(64, [3, 3], strides=[2, 2])(x)
    x = linen.BatchNorm(use_running_average=not training)(x)
    x = linen.Dropout(0.1, deterministic=not training)(x)
    x = jax.nn.relu(x)

    # Block 3
    x = linen.Conv(128, [3, 3], strides=[2, 2])(x)

    # Global average pooling
    x = x.mean(axis=(1, 2))

    # Classification layer
    x = linen.Dense(10)(x)

    return x

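# A minimal usage sketch for modules like the classifier above. The helper
# function and the way the classifier is wrapped in a linen.Module are
# illustrative assumptions, not from the source. It shows the two pieces of
# state handling these snippets rely on: BatchNorm keeps running statistics in
# the 'batch_stats' collection, so training calls must mark that collection as
# mutable, and Dropout needs a 'dropout' PRNG stream when not deterministic.
import jax

def example_train_and_eval_step(model, images):
    rng = jax.random.PRNGKey(0)
    init_rng, dropout_rng = jax.random.split(rng)

    # Initialize parameters and running batch statistics.
    variables = model.init(init_rng, images, training=False)
    params, batch_stats = variables['params'], variables['batch_stats']

    # Training mode: batch statistics are updated and returned separately.
    logits, updates = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        training=True,
        rngs={'dropout': dropout_rng},
        mutable=['batch_stats'])
    batch_stats = updates['batch_stats']

    # Evaluation mode: running averages are used and nothing is mutated.
    eval_logits = model.apply(
        {'params': params, 'batch_stats': batch_stats},
        images,
        training=False)
    return logits, eval_logits
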
def __call__(self, inputs, train: bool = False):
    x = ASPP([12, 24, 36], name='ASPP')(inputs)
    x = nn.Conv(256, (3, 3), padding='SAME', use_bias=False, name="conv1")(x)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)
    x = nn.Conv(self.num_classes, (1, 1), padding='VALID', use_bias=True,
                name="conv2")(x)
    return x

def __call__(self, inputs, is_training: bool):
    x = nn.Conv(features=self.num_ch,
                use_bias=self.use_bias,
                kernel_size=(self.conv_kernel_size, self.conv_kernel_size),
                strides=(self.conv_stride, self.conv_stride),
                padding=[(self.patch_shape[0],) * 2,
                         (self.patch_shape[1],) * 2])(inputs)
    x = nn.BatchNorm(use_running_average=not is_training,
                     momentum=self.bn_momentum,
                     epsilon=self.bn_epsilon,
                     dtype=self.dtype)(x)
    x = nn.max_pool(
        inputs=x,
        window_shape=(self.pool_window_size,) * 2,
        strides=(self.pool_stride,) * 2,
    )
    x = rearrange(
        x,
        'b (h ph) (w pw) c -> b (h w) (ph pw c)',
        ph=self.patch_shape[0],
        pw=self.patch_shape[1],
    )
    output = nn.Dense(features=self.embed_dim,
                      use_bias=self.use_bias,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x)
    return output

def __call__(self, inputs, train: bool = False):
    _d = max(1, self.dilation)
    # Zero-pad spatially by the dilation rate so the dilated 3x3 convolution
    # with 'VALID' padding preserves the spatial shape.
    x = jnp.pad(inputs, [(0, 0), (_d, _d), (_d, _d), (0, 0)],
                mode='constant', constant_values=0)
    x = nn.Conv(self.channels, (3, 3), padding='VALID',
                kernel_dilation=(_d, _d), use_bias=False, name='conv1')(x)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)
    return x

def __call__(self, inputs):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(inputs)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=dtype,
                     name='init_bn')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
        proj = block_hparams.conv_proj
        # For projection layers (unless it is the first layer), strides = (2, 2)
        if i > 0 and proj is not None:
            filter_multiplier *= 2
            strides = (2, 2)
        else:
            strides = (1, 1)
        x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                          hparams=block_hparams,
                          quant_context=self.quant_context,
                          strides=strides,
                          train=self.train,
                          dtype=dtype)(x)
    x = jnp.mean(x, axis=(1, 2))

    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)

    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output

def __call__(self, x):
    norm = nn.BatchNorm(
        name="norm",
        use_running_average=False,
        axis_name="batch",
    )
    x = norm(x)
    return x, norm(x)

def __call__(self, x):
    x = nn.Conv(self.filters, (1, 1), use_bias=False, dtype=jnp.float32)(x)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=jnp.float32)(x)
    return x

def __call__(self, inputs, train: bool = False):
    inter_channels = np.shape(inputs)[-1] // 4
    x = nn.Conv(inter_channels, (3, 3), padding='SAME', use_bias=False,
                name="conv1")(inputs)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)
    x = nn.Dropout(0.1)(x, deterministic=not train)
    x = nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=True,
                name="conv2")(x)
    return x

def test_batch_norm(self): """Test virtual BN recovers BN when the virtual size equals batch size.""" rng = jax.random.PRNGKey(0) batch_size = 10 feature_size = 7 input_shape = (batch_size, 3, 3, feature_size) half_input_shape = (batch_size // 2, 3, 3, feature_size) twos = 2.0 * jnp.ones(half_input_shape) nines = 9.0 * jnp.ones(half_input_shape) x = jnp.concatenate((twos, nines)) bn_flax_module = nn.BatchNorm(momentum=0.9) bn_params, bn_state = _init(bn_flax_module, rng, input_shape) vbn_flax_module = normalization.VirtualBatchNorm( momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC') vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape) _, bn_state = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x, mutable=['batch_stats'], use_running_average=False) bn_state = bn_state['batch_stats'] bn_y, bn_state = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x, mutable=['batch_stats'], use_running_average=False) bn_state = bn_state['batch_stats'] _, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x, mutable=['batch_stats'], use_running_average=False) vbn_state = vbn_state['batch_stats'] vbn_y, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x, mutable=['batch_stats'], use_running_average=False) vbn_state = vbn_state['batch_stats'] # Test that the layer forward passes are the same. np.testing.assert_allclose(bn_y, vbn_y, atol=1e-4) # Test that virtual and regular BN produce the same EMAs. np.testing.assert_allclose( bn_state['mean'], np.squeeze(vbn_state['batch_norm_running_mean']), atol=1e-4) np.testing.assert_allclose( bn_state['var'], np.squeeze(vbn_state['batch_norm_running_var']), atol=1e-4)
def __call__(self, inputs, train: bool = False):
    res = []
    x = nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=False,
                name="conv1")(inputs)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    res.append(nn.relu(x))
    for i, rate in enumerate(self.atrous_rates):
        res.append(ASPPConv(self.channels, rate, name=f'ASPPConv{i+1}')(inputs))
    res.append(ASPPPooling(self.channels, name='ASPPPooling')(inputs))
    x = jnp.concatenate(res, -1)  # 1280
    x = nn.Conv(self.channels, (1, 1), padding='VALID', use_bias=False,
                name="conv2")(x)
    x = nn.BatchNorm(use_running_average=not train, name="bn2")(x)
    x = nn.relu(x)
    x = nn.Dropout(0.5)(x, deterministic=not train)
    return x

def __call__(self, inputs, is_training):
    h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(inputs)
    h = nn.Dense(self.vocab_size, use_bias=False)(h)
    return nn.BatchNorm(
        use_bias=False,
        use_scale=False,
        momentum=0.9,
        use_running_average=not is_training,
    )(h)

def __call__(self, inputs, train: bool = False):
    in_shape = np.shape(inputs)[1:-1]
    x = nn.avg_pool(inputs, in_shape)
    x = nn.Conv(self.channels, (1, 1), padding='SAME', use_bias=False,
                name="conv1")(x)
    x = nn.BatchNorm(use_running_average=not train, name="bn1")(x)
    x = nn.relu(x)
    out_shape = (1, in_shape[0], in_shape[1], self.channels)
    x = jax.image.resize(x, shape=out_shape, method='bilinear')
    return x

def setup(self):
    activation = nn.softplus if self.activation == 'softplus' else nn.relu
    if self.group_norm:
        self.double_conv = Sequential([
            nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
            nn.GroupNorm(self.num_groups),
            activation,
            nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
            nn.GroupNorm(self.num_groups),
            activation,
        ])
    else:
        self.double_conv = Sequential([
            nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False),
            nn.BatchNorm(use_running_average=self.test),
            activation,
            nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False),
            nn.BatchNorm(use_running_average=self.test),
            activation,
        ])

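# In the setup() above, `use_running_average` is frozen to `self.test` when the
# layers are constructed. Flax's nn.BatchNorm also accepts `use_running_average`
# as a call-time argument, which lets one module instance serve both training
# and evaluation. A minimal sketch of that alternative; the module name and
# attributes below are illustrative assumptions, not from the source.
import flax.linen as nn

class DoubleConvSketch(nn.Module):
    mid_channels: int
    out_channels: int

    @nn.compact
    def __call__(self, x, train: bool = True):
        x = nn.Conv(self.mid_channels, kernel_size=(3, 3), use_bias=False)(x)
        # Leave `use_running_average` unset on the layer and pass it per call.
        x = nn.BatchNorm(momentum=0.9)(x, use_running_average=not train)
        x = nn.relu(x)
        x = nn.Conv(self.out_channels, kernel_size=(3, 3), use_bias=False)(x)
        x = nn.BatchNorm(momentum=0.9)(x, use_running_average=not train)
        x = nn.relu(x)
        return x
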
def __call__(self, x, train=False):
    for v in self.cfg:
        if v == 'M':
            x = nn.max_pool(x, (2, 2), (2, 2))
        else:
            x = nn.Conv(v, (3, 3), padding='SAME', dtype=self.dtype)(x)
            if self.batch_norm:
                # Note: in Flax, `momentum` is the EMA decay rate, so 0.1 keeps
                # only 10% of the running history per step (the opposite of
                # PyTorch's momentum convention).
                x = nn.BatchNorm(use_running_average=not train, momentum=0.1,
                                 dtype=self.dtype)(x)
            x = nn.relu(x)
    return x

def __call__(self, x):
    x = nn.Dense(10)(x)
    if dropout:
        x = nn.Dropout(0.5, deterministic=False)(x)
    if batchnorm:
        x = nn.BatchNorm(
            use_bias=True,
            use_scale=True,
            momentum=0.999,
            use_running_average=False,
        )(x)
    return x

def __call__(self, inputs, is_training):
    h = nn.softplus(nn.Dense(self.hidden)(inputs))
    h = nn.softplus(nn.Dense(self.hidden)(h))
    h = nn.Dropout(self.dropout_rate, deterministic=not is_training)(h)
    h = nn.Dense(self.num_topics)(h)
    log_concentration = nn.BatchNorm(
        use_bias=False,
        use_scale=False,
        momentum=0.9,
        use_running_average=not is_training,
    )(h)
    return jnp.exp(log_concentration)

def __call__(self, x, train):
    x = nn.BatchNorm(use_running_average=not train,
                     momentum=0.9,
                     epsilon=1e-5,
                     name='init_bn',
                     axis_name=self.axis_name,
                     axis_index_groups=self.axis_index_groups,
                     dtype=self.dtype)(x)
    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes,
                 kernel_init=nn.initializers.normal(),
                 dtype=self.dtype)(x)
    return x

def __call__(self, x):
    x = nn.Conv(self.out_channels,
                kernel_size=self.kernel_size,
                strides=self.strides,
                padding=self.padding,
                use_bias=False,
                name='conv',
                dtype=self.dtype)(x)
    x = nn.BatchNorm(use_running_average=not self.train,
                     epsilon=0.001,
                     name='bn',
                     dtype=self.dtype)(x)
    return nn.relu(x)

def __call__(self, x):
    if self.minimalistic:
        kernel = 3
        activation = relu
        se_ratio = None
    else:
        kernel = 5
        activation = hard_swish
        se_ratio = 0.25

    # Input processing (shared between small and large variants).
    x = x / 255
    x = nn.Conv(features=16, kernel_size=(3, 3), strides=(2, 2),
                use_bias=False)(x)
    x = activation(x)

    # Main network
    get_args = _get_large_args if self.large else _get_small_args
    for args in get_args(kernel, activation, se_ratio, self.alpha):
        x = ResidualInvertedBottleneck(*args, batch_norm=self.batch_norm)(x)

    # Last stages (shared between small and large variants).
    x = nn.Conv(features=_depth(x.shape[-1] * 6), kernel_size=(1, 1),
                use_bias=False)(x)
    if self.batch_norm:
        x = nn.BatchNorm(epsilon=1e-3, momentum=0.999)(x)
    x = activation(x)
    if self.alpha > 1.0:
        last_point_features = _depth(self.last_point_features * self.alpha)
    else:
        last_point_features = self.last_point_features
    x = nn.Conv(features=last_point_features, kernel_size=(1, 1),
                use_bias=True)(x)
    x = activation(x)
    x = global_average_pooling(x)
    x = x.reshape((x.shape[0], 1, 1, last_point_features))
    x = nn.Conv(features=self.classes, kernel_size=(1, 1))(x)
    x = flatten(x)
    # x = self.classifier_activation(x)
    return x

def __call__(self, x):
    residual = x
    y = nn.Conv(
        features=self.num_filters * 4,
        kernel_size=(1, 1),
        name="b1_Conv_2",
        use_bias=False,
        dtype=self.dtype,
    )(x)
    residual = nn.Conv(
        features=self.num_filters * 4,
        kernel_size=(1, 1),
        strides=(1, 1),
        name="b1_conv_proj",
        use_bias=False,
        dtype=self.dtype,
    )(residual)
    residual = nn.BatchNorm(
        name="b1_norm_proj",
        use_running_average=False,
        momentum=0.9,
        epsilon=1e-5,
        dtype=self.dtype,
    )(residual)
    x = residual + y

    residual = x
    y = nn.Conv(
        features=self.num_filters * 4,
        kernel_size=(3, 3),
        strides=(1, 1),
        name="b2_Conv_1",
        use_bias=False,
        dtype=self.dtype,
    )(x)
    var_tracker = self.variable("zzz_grad_stats", "b2_conv_proj_dummy",
                                jnp.zeros, residual.shape)
    residual = residual + var_tracker.value
    residual = nn.Conv(
        features=self.num_filters * 4,
        kernel_size=(1, 1),
        strides=(1, 1),
        name="b2_conv_proj",
        use_bias=False,
        dtype=self.dtype,
    )(residual)
    x = residual + y
    return x

def batchnorm2d(
    eps=1e-3,
    momentum=0.99,
    affine=True,
    training=True,
    name: Optional[str] = None,
    bias_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.zeros,
    weight_init: Callable[[PRNGKey, Shape, Dtype], Array] = initializers.ones,
):
    return nn.BatchNorm(
        use_running_average=not training,
        momentum=momentum,
        epsilon=eps,
        use_bias=affine,
        use_scale=affine,
        name=name,
        bias_init=bias_init,
        scale_init=weight_init,
    )

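# A small usage sketch for the PyTorch-style factory above. The wrapping
# module (`BlockSketch`), its shapes, and the hyperparameters are illustrative
# assumptions. The point is that the returned nn.BatchNorm still lives inside
# a Flax module, so its running statistics have to be carried in the
# 'batch_stats' collection by the caller. (Also note that Flax's `momentum` is
# the EMA decay, so it is not numerically interchangeable with PyTorch's
# BatchNorm2d momentum.)
import jax
import jax.numpy as jnp
import flax.linen as nn

class BlockSketch(nn.Module):
    features: int = 32
    training: bool = True

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(self.features, (3, 3), use_bias=False)(x)
        x = batchnorm2d(momentum=0.99, training=self.training)(x)
        return nn.relu(x)

block = BlockSketch(training=True)
x = jnp.ones((2, 8, 8, 3))
variables = block.init(jax.random.PRNGKey(0), x)
y, state = block.apply(variables, x, mutable=['batch_stats'])
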
def __call__(self, inputs, is_training: bool):
    in_ch = inputs.shape[-1]
    conv = partial(nn.Conv, dtype=self.dtype)
    x = conv(features=in_ch,
             kernel_size=(self.kernel_size, self.kernel_size),
             strides=(self.strides, self.strides),
             padding='SAME',
             feature_group_count=in_ch,
             use_bias=False)(inputs)
    x = nn.BatchNorm(use_running_average=not is_training,
                     momentum=self.bn_momentum,
                     epsilon=self.bn_epsilon,
                     dtype=self.dtype)(x)
    output = conv(features=self.out_ch, kernel_size=(1, 1),
                  use_bias=self.use_bias)(x)
    return output

def test_batch_norm(self):
    rng = random.PRNGKey(0)
    key1, key2 = random.split(rng)
    x = random.normal(key1, (4, 3, 2))
    model_cls = nn.BatchNorm(momentum=0.9)
    y, initial_params = model_cls.init_with_output(key2, x)

    mean = y.mean((0, 1))
    var = y.var((0, 1))
    np.testing.assert_allclose(mean, np.array([0., 0.]), atol=1e-4)
    np.testing.assert_allclose(var, np.array([1., 1.]), rtol=1e-4)

    y, vars_out = model_cls.apply(initial_params, x, mutable=['batch_stats'])
    ema = vars_out['batch_stats']
    np.testing.assert_allclose(
        ema['mean'], 0.1 * x.mean((0, 1), keepdims=False), atol=1e-4)
    np.testing.assert_allclose(
        ema['var'], 0.9 + 0.1 * x.var((0, 1), keepdims=False), rtol=1e-4)

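# The assertions above encode BatchNorm's running-statistics update. In Flax's
# convention each training call performs
#     running <- momentum * running + (1 - momentum) * batch_statistic,
# with the running mean initialized to 0 and the running variance to 1, so one
# call with momentum = 0.9 yields exactly the expected values checked by the
# test. A standalone arithmetic sketch (array shapes are arbitrary):
import numpy as np

momentum = 0.9
x = np.random.normal(size=(4, 3, 2))
batch_mean = x.mean((0, 1))
batch_var = x.var((0, 1))

running_mean = momentum * np.zeros_like(batch_mean) + (1 - momentum) * batch_mean
running_var = momentum * np.ones_like(batch_var) + (1 - momentum) * batch_var

assert np.allclose(running_mean, 0.1 * batch_mean)
assert np.allclose(running_var, 0.9 + 0.1 * batch_var)
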
def __call__(self, inputs, train): """Applies the network to inputs. Args: inputs: (batch_size, resolution, resolution, n_spins, n_channels) array. train: whether to run in training or inference mode. Returns: A (batch_size, num_classes) float32 array with per-class scores (logits). Raises: ValueError: If resolutions cannot be enforced with 2x2 pooling. """ num_layers = len(self.resolutions) # Merge spin and channel dimensions. features = inputs.reshape((*inputs.shape[:3], -1)) for layer_id in range(num_layers - 1): resolution_in = self.resolutions[layer_id] resolution_out = self.resolutions[layer_id + 1] n_channels = self.widths[layer_id + 1] if resolution_out == resolution_in // 2: features = nn.avg_pool(features, window_shape=(2, 2), strides=(2, 2), padding='SAME') elif resolution_out != resolution_in: raise ValueError( 'Consecutive resolutions must be equal or halved.') features = nn.Conv(features=n_channels, kernel_size=(3, 3), strides=(1, 1))(features) features = nn.BatchNorm(use_running_average=not train, axis_name=self.axis_name)(features) features = nn.relu(features) features = jnp.mean(features, axis=(1, 2)) features = nn.Dense(self.num_classes)(features) return features
def setup(self):
    self.a = nn.Dense(3)
    self.bn = nn.BatchNorm()
    self.b = nn.Dense(1)

def __call__(self, x):
    h = nn.Dense(1)(x)
    h = nn.BatchNorm()(h)
    return nn.Dense(x.shape[-1])(x)