def __call__(self, x, train=False):
    conv_block = partial(BasicConv, train=train, dtype=self.dtype)
    inception_a = partial(InceptionA, conv_block=conv_block)
    inception_b = partial(InceptionB, conv_block=conv_block)
    inception_c = partial(InceptionC, conv_block=conv_block)
    inception_d = partial(InceptionD, conv_block=conv_block)
    inception_e = partial(InceptionE, conv_block=conv_block)
    inception_aux = partial(InceptionAux, conv_block=conv_block)

    if self.transform_input:
        # Re-scale inputs from ImageNet normalization to the [-1, 1] range.
        x = jnp.transpose(x, (0, 3, 1, 2))
        x_ch0 = jnp.expand_dims(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
        x_ch1 = jnp.expand_dims(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
        x_ch2 = jnp.expand_dims(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
        x = jnp.concatenate((x_ch0, x_ch1, x_ch2), 1)
        x = jnp.transpose(x, (0, 2, 3, 1))

    x = conv_block(32, kernel_size=(3, 3), strides=(2, 2), name='Conv2d_1a_3x3')(x)
    x = conv_block(32, kernel_size=(3, 3), name='Conv2d_2a_3x3')(x)
    x = conv_block(64, kernel_size=(3, 3), padding=[(1, 1), (1, 1)], name='Conv2d_2b_3x3')(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2))
    x = conv_block(80, kernel_size=(1, 1), name='Conv2d_3b_1x1')(x)
    x = conv_block(192, kernel_size=(3, 3), name='Conv2d_4a_3x3')(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2))

    x = inception_a(pool_features=32, name='Mixed_5b')(x)
    x = inception_a(pool_features=64, name='Mixed_5c')(x)
    x = inception_a(pool_features=64, name='Mixed_5d')(x)
    x = inception_b(name='Mixed_6a')(x)
    x = inception_c(channels_7x7=128, name='Mixed_6b')(x)
    x = inception_c(channels_7x7=160, name='Mixed_6c')(x)
    x = inception_c(channels_7x7=160, name='Mixed_6d')(x)
    x = inception_c(channels_7x7=192, name='Mixed_6e')(x)

    aux = (inception_aux(self.num_classes, name='AuxLogits')(x)
           if train and self.aux_logits else None)

    x = inception_d(name='Mixed_7a')(x)
    x = inception_e(name='Mixed_7b')(x)
    x = inception_e(name='Mixed_7c')(x)
    x = nn.avg_pool(x, (8, 8))
    x = nn.Dropout(0.5)(x, deterministic=not train)
    return x, aux
def __call__(self, inputs, is_training: bool):
    x = nn.Conv(features=self.num_ch,
                use_bias=self.use_bias,
                kernel_size=(self.conv_kernel_size, self.conv_kernel_size),
                strides=(self.conv_stride, self.conv_stride),
                padding=[(self.patch_shape[0],) * 2, (self.patch_shape[1],) * 2])(inputs)
    x = nn.BatchNorm(use_running_average=not is_training,
                     momentum=self.bn_momentum,
                     epsilon=self.bn_epsilon,
                     dtype=self.dtype)(x)
    x = nn.max_pool(inputs=x,
                    window_shape=(self.pool_window_size,) * 2,
                    strides=(self.pool_stride,) * 2)
    x = rearrange(x,
                  'b (h ph) (w pw) c -> b (h w) (ph pw c)',
                  ph=self.patch_shape[0],
                  pw=self.patch_shape[1])
    output = nn.Dense(features=self.embed_dim,
                      use_bias=self.use_bias,
                      dtype=self.dtype,
                      precision=self.precision,
                      kernel_init=self.kernel_init,
                      bias_init=self.bias_init)(x)
    return output
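# Hedged illustration (not part of the source): how the einops pattern above
# folds an NHWC feature map into patch tokens. The shapes below are
# illustrative only.
import jax.numpy as jnp
from einops import rearrange

feats = jnp.zeros((2, 8, 8, 16))  # (batch, height, width, channels)
tokens = rearrange(feats, 'b (h ph) (w pw) c -> b (h w) (ph pw c)', ph=4, pw=4)
assert tokens.shape == (2, 4, 256)  # a 2x2 grid of 4x4x16 patches per image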
def setup(self):
    self.maxpool_conv = Sequential([
        lambda x: nn.max_pool(x, (2, 2), (2, 2)),
        DoubleConv(self.in_channels, self.out_channels, self.out_channels,
                   self.test, self.group_norm, self.num_groups, self.activation),
    ])
def __call__(self, x):
    initializer = nn.initializers.xavier_uniform()
    conv_out = nn.Conv(features=self.num_ch,
                       kernel_size=(3, 3),
                       strides=1,
                       kernel_init=initializer,
                       padding='SAME')(x)
    if self.use_max_pooling:
        conv_out = nn.max_pool(conv_out,
                               window_shape=(3, 3),
                               padding='SAME',
                               strides=(2, 2))
    for _ in range(self.num_blocks):
        block_input = conv_out
        conv_out = nn.relu(conv_out)
        conv_out = nn.Conv(features=self.num_ch, kernel_size=(3, 3),
                           strides=1, padding='SAME')(conv_out)
        conv_out = nn.relu(conv_out)
        conv_out = nn.Conv(features=self.num_ch, kernel_size=(3, 3),
                           strides=1, padding='SAME')(conv_out)
        conv_out += block_input
    return conv_out
def __call__(self, x):
    branch3x3 = self.conv_block(192, kernel_size=(1, 1), name='branch3x3_1')(x)
    branch3x3 = self.conv_block(320, kernel_size=(3, 3), strides=(2, 2),
                                name='branch3x3_2')(branch3x3)

    branch7x7x3 = self.conv_block(192, kernel_size=(1, 1), name='branch7x7x3_1')(x)
    branch7x7x3 = self.conv_block(192, kernel_size=(1, 7), padding=[(0, 0), (3, 3)],
                                  name='branch7x7x3_2')(branch7x7x3)
    branch7x7x3 = self.conv_block(192, kernel_size=(7, 1), padding=[(3, 3), (0, 0)],
                                  name='branch7x7x3_3')(branch7x7x3)
    branch7x7x3 = self.conv_block(192, kernel_size=(3, 3), strides=(2, 2),
                                  name='branch7x7x3_4')(branch7x7x3)

    branch_pool = nn.max_pool(x, (3, 3), strides=(2, 2))

    outputs = [branch3x3, branch7x7x3, branch_pool]
    return jnp.concatenate(outputs, 3)
def __call__(self, x, train: bool = True):
    conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
    norm = partial(nn.BatchNorm,
                   use_running_average=not train,
                   momentum=0.9,
                   epsilon=1e-5,
                   dtype=self.dtype)

    x = conv(self.num_filters, (7, 7), (2, 2),
             padding=[(3, 3), (3, 3)],
             name='conv_init')(x)
    x = norm(name='bn_init')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for i, block_size in enumerate(self.stage_sizes):
        for j in range(block_size):
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            x = self.block_cls(self.num_filters * 2**i,
                               strides=strides,
                               conv=conv,
                               norm=norm,
                               act=self.act)(x)

    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
    x = jnp.asarray(x, self.dtype)
    x = nn.log_softmax(x)
    return x
def __call__(self, x, train):
    maybe_normalize = model_utils.get_normalizer(self.normalizer, train)
    iterator = zip(self.num_filters, self.kernel_sizes, self.kernel_paddings,
                   self.window_sizes, self.window_paddings, self.strides)
    for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator:
        x = nn.Conv(num_filters, (kernel_size, kernel_size), (1, 1),
                    padding=kernel_padding,
                    kernel_init=self.kernel_init,
                    bias_init=self.bias_init)(x)
        x = model_utils.ACTIVATIONS[self.activation_fn](x)
        x = maybe_normalize()(x)
        x = nn.max_pool(x,
                        window_shape=(window_size, window_size),
                        strides=(stride, stride),
                        padding=window_padding)

    x = jnp.reshape(x, (x.shape[0], -1))
    for num_units in self.num_dense_units:
        x = nn.Dense(num_units,
                     kernel_init=self.kernel_init,
                     bias_init=self.bias_init)(x)
        x = model_utils.ACTIVATIONS[self.activation_fn](x)
        x = maybe_normalize()(x)
    x = nn.Dense(self.num_outputs,
                 kernel_init=self.kernel_init,
                 bias_init=self.bias_init)(x)
    return x
def __call__(self, inputs):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(inputs)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=dtype,
                     name='init_bn')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
        proj = block_hparams.conv_proj
        # For projection layers (unless it is the first layer), strides = (2, 2).
        if i > 0 and proj is not None:
            filter_multiplier *= 2
            strides = (2, 2)
        else:
            strides = (1, 1)
        x = ResidualBlock(filters=int(num_filters * filter_multiplier),
                          hparams=block_hparams,
                          quant_context=self.quant_context,
                          strides=strides,
                          train=self.train,
                          dtype=dtype)(x)

    x = jnp.mean(x, axis=(1, 2))
    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)
    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output
def __call__(self, inputs, train: bool = True):
    """Passes the input through the network.

    Arguments:
      inputs: [batch_size, height, width, channels]
      train: bool

    Returns:
      output: [batch_size, config.num_classes]
    """
    conv = partial(nn.Conv,
                   use_bias=False,
                   dtype=self.dtype,
                   precision=self.precision,
                   kernel_init=self.kernel_init)
    norm = partial(nn.BatchNorm,
                   use_running_average=not train,
                   momentum=self.bn_momentum,
                   epsilon=self.bn_epsilon,
                   dtype=self.dtype)

    y = conv(self.initial_filters,
             kernel_size=(7, 7),
             strides=(2, 2),
             padding=[(3, 3), (3, 3)])(inputs)
    y = norm()(y)
    y = self.activation_fn(y)
    y = nn.max_pool(y, (3, 3), strides=(2, 2), padding='SAME')

    for i, block_size in enumerate(self.stage_sizes[:-1]):
        for j in range(block_size):
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            y = BottleneckResNetBlock(
                filters=self.initial_filters * 2**i,
                strides=strides,
                conv=conv,
                norm=norm,
                se_ratio=self.se_ratio,
                projection_factor=self.projection_factor,
                activation_fn=self.activation_fn,
                dtype=self.dtype,
            )(y)

    for j in range(self.stage_sizes[-1]):
        strides = (2, 2) if j == 0 and self.stride_one is False else (1, 1)
        y = BoTBlock(filters=self.initial_filters * 2**(i + 1),
                     strides=strides,
                     conv=conv,
                     norm=norm,
                     projection_factor=self.projection_factor,
                     activation_fn=self.activation_fn)(y)

    y = jnp.mean(y, axis=(1, 2))
    y = nn.Dense(self.num_classes,
                 dtype=self.dtype,
                 kernel_init=self.kernel_init,
                 bias_init=self.bias_init)(y)
    y = jnp.asarray(y, dtype=self.dtype)
    return y
def __call__(self, x):
    for feat in self.conv_features:
        x = nn.Conv(feat, kernel_size=(3, 3))(x)
        x = nn.max_pool(x, window_shape=(2, 2))
        x = nn.relu(x)
    x = x.reshape((x.shape[0], -1))
    for feat in self.mlp_features[:-1]:
        x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.mlp_features[-1])(x)
    return x
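# Hedged usage sketch (not from the source): how a module like the one above is
# typically constructed and applied in Flax. `SmallCNN` is a hypothetical
# stand-in with the same __call__ body; the feature sizes are illustrative.
import jax
import jax.numpy as jnp
import flax.linen as nn

class SmallCNN(nn.Module):
    conv_features: tuple = (16, 32)
    mlp_features: tuple = (64, 10)

    @nn.compact
    def __call__(self, x):
        for feat in self.conv_features:
            x = nn.Conv(feat, kernel_size=(3, 3))(x)
            x = nn.max_pool(x, window_shape=(2, 2))
            x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        for feat in self.mlp_features[:-1]:
            x = nn.relu(nn.Dense(feat)(x))
        return nn.Dense(self.mlp_features[-1])(x)

model = SmallCNN()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 28, 28, 1)))
logits = model.apply(params, jnp.ones((8, 28, 28, 1)))  # shape (8, 10)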
def __call__(self, x):
    x = self.act(x)
    path = x
    for _ in range(self.n_stages):
        path = nn.max_pool(path, window_shape=(5, 5), strides=(1, 1), padding='SAME')
        path = ncsn_conv3x3(path, self.features, stride=1, bias=False)
        x = path + x
    return x
def __call__(self, x, train=False):
    for v in self.cfg:
        if v == 'M':
            x = nn.max_pool(x, (2, 2), (2, 2))
        else:
            x = nn.Conv(v, (3, 3), padding='SAME', dtype=self.dtype)(x)
            if self.batch_norm:
                x = nn.BatchNorm(use_running_average=not train,
                                 momentum=0.1,
                                 dtype=self.dtype)(x)
            x = nn.relu(x)
    return x
def __call__(self, x):
    x = nn.Conv(features=16, kernel_size=(3, 3), padding='SAME')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=32, kernel_size=(3, 3), padding='SAME')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3), padding='SAME')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=128)(x)
    x = nn.relu(x)
    x = nn.Dense(features=NB_CLASSES)(x)
    x = nn.softmax(x)
    return x
def basic_module(self, x):
    x = nn.Conv(features=90, kernel_size=(9, 9), padding='VALID', dtype=jp.float64)(x)
    x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.ConvTranspose(features=1, kernel_size=(2, 2), strides=(2, 2), dtype=jp.float64)(x)
    x = x.reshape(x.shape[0], -1)
    x = jp.prod(x, 1)
    return x
def __call__(self, x, train):
    if self.num_layers not in _block_size_options:
        raise ValueError('Please provide a valid number of layers')
    block_sizes = _block_size_options[self.num_layers]

    x = nn.Conv(self.num_filters, (7, 7), (2, 2),
                use_bias=False,
                dtype=self.dtype,
                name='init_conv')(x)
    if self.use_bn:
        x = normalization.VirtualBatchNorm(
            momentum=self.batch_norm_momentum,
            epsilon=self.batch_norm_epsilon,
            dtype=self.dtype,
            name='init_bn',
            batch_size=self.batch_size,
            virtual_batch_size=self.virtual_batch_size,
            total_batch_size=self.total_batch_size,
            data_format=self.data_format)(x, use_running_average=not train)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    if self.block_type == 'post_activation':
        residual_block = ResidualBlock
    elif self.block_type == 'pre_activation':
        residual_block = PreActResidualBlock
    else:
        raise ValueError('Invalid Block Type: {}'.format(self.block_type))

    index = 0
    for i, block_size in enumerate(block_sizes):
        for j in range(block_size):
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            index += 1
            x = residual_block(
                self.num_filters * 2**i,
                strides=strides,
                dtype=self.dtype,
                batch_norm_momentum=self.batch_norm_momentum,
                batch_norm_epsilon=self.batch_norm_epsilon,
                batch_size=self.batch_size,
                virtual_batch_size=self.virtual_batch_size,
                total_batch_size=self.total_batch_size,
                data_format=self.data_format,
                bn_relu_conv=self.bn_relu_conv,
                use_bn=self.use_bn,
                activation_function=self.activation_function)(x, train=train)

    x = jnp.mean(x, axis=(1, 2))
    if self.dropout_rate > 0.0:
        x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
    x = nn.Dense(self.num_outputs, dtype=self.dtype)(x)
    return x
def features(x, num_layers, normalizer, dtype, train):
    """Implements the feature extraction portion of the network."""
    layers = _layer_size_options[num_layers]
    conv = functools.partial(nn.Conv, use_bias=False, dtype=dtype)
    maybe_normalize = model_utils.get_normalizer(normalizer, train)
    for l in layers:
        if l == 'M':
            x = nn.max_pool(x, window_shape=(2, 2), strides=(2, 2))
        else:
            x = conv(features=l, kernel_size=(3, 3), padding=((1, 1), (1, 1)))(x)
            x = maybe_normalize()(x)
            x = nn.relu(x)
    return x
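# Hedged illustration (not the repo's actual table): the kind of layer spec a
# config-driven extractor like `features` above consumes. Integers are conv
# widths and 'M' inserts a 2x2 max pool; the values follow the common VGG-11
# layout and are shown only as an example.
example_layer_size_options = {
    11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
}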
def test_max_pool(self):
    x = jnp.arange(9).reshape((1, 3, 3, 1)).astype(jnp.float32)
    pool = lambda x: nn.max_pool(x, (2, 2))
    expected_y = jnp.array([
        [4., 5.],
        [7., 8.],
    ]).reshape((1, 2, 2, 1))
    y = pool(x)
    np.testing.assert_allclose(y, expected_y)
    y_grad = jax.grad(lambda x: pool(x).sum())(x)
    expected_grad = jnp.array([
        [0., 0., 0.],
        [0., 1., 1.],
        [0., 1., 1.],
    ]).reshape((1, 3, 3, 1))
    np.testing.assert_allclose(y_grad, expected_grad)
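# Hedged companion check (not part of the source test): the (3, 3)-window,
# stride-(2, 2), 'SAME'-padded max pool used by the ResNet stems above halves
# the spatial dimensions, e.g. a 112x112x64 stem activation becomes 56x56x64.
import jax.numpy as jnp
import flax.linen as nn

stem = jnp.zeros((1, 112, 112, 64))
pooled = nn.max_pool(stem, (3, 3), strides=(2, 2), padding='SAME')
assert pooled.shape == (1, 56, 56, 64)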
def __call__(self, inputs, train: bool = False):
    norm = functools.partial(nn.BatchNorm,
                             use_running_average=not train,
                             momentum=0.9,
                             epsilon=1e-5,
                             dtype=self.dtype)

    # Replace 2x2 strides with dilated convs where requested.
    use_dilation = self.use_dilation if self.use_dilation is not None else [False, False, False]
    if len(use_dilation) != 3:
        raise ValueError("use_dilation should be None "
                         "or a 3-element tuple, got {}".format(use_dilation))

    x = nn.Conv(64, (7, 7), (2, 2),
                padding=[(3, 3), (3, 3)],
                use_bias=False,
                dtype=self.dtype,
                name='conv1')(inputs)
    x = norm(name='bn1')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding=[(1, 1), (1, 1)])

    dilation = 1
    for i, block_size in enumerate(self.layers):
        features = 64 * 2**i
        downsample = False
        previous_dilation = dilation
        strides = (2, 2) if i > 0 else (1, 1)
        if i > 0 and use_dilation[i - 1]:
            dilation *= strides[0]
            strides = (1, 1)
        block_expansion = 4 if "Bottleneck" in self.block.__name__ else 1
        if strides != (1, 1) or x.shape[-1] != features * block_expansion:
            downsample = True
        kwargs = {
            'features': features,
            'strides': strides,
            'downsample': downsample,
            'groups': self.groups,
            'dilation': previous_dilation,
            'base_width': self.width_per_group,
            'norm': norm,
            'dtype': self.dtype,
        }
        x = Layer(self.block, block_size, dilation, kwargs, name=f'layer{i+1}')(x)
    return x
def __call__(self, x, train):
    del train
    encoder_keys = [
        'filter_sizes',
        'kernel_sizes',
        'kernel_paddings',
        'window_sizes',
        'window_paddings',
        'strides',
        'activations',
    ]
    if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1:
        raise ValueError(
            'The elements in encoder dict do not have the same length.')

    decoder_keys = [
        'filter_sizes',
        'kernel_sizes',
        'window_sizes',
        'paddings',
        'activations',
    ]
    if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1:
        raise ValueError(
            'The elements in decoder dict do not have the same length.')

    # encoder
    for i in range(len(self.encoder['filter_sizes'])):
        x = nn.Conv(self.encoder['filter_sizes'][i],
                    self.encoder['kernel_sizes'][i],
                    padding=self.encoder['kernel_paddings'][i])(x)
        x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x)
        x = nn.max_pool(x,
                        self.encoder['window_sizes'][i],
                        strides=self.encoder['strides'][i],
                        padding=self.encoder['window_paddings'][i])

    # decoder
    for i in range(len(self.decoder['filter_sizes'])):
        x = nn.ConvTranspose(self.decoder['filter_sizes'][i],
                             self.decoder['kernel_sizes'][i],
                             self.decoder['window_sizes'][i],
                             padding=self.decoder['paddings'][i])(x)
        x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x)
    return x
def __call__(self, x, train):
    if self.num_layers not in _block_size_options:
        raise ValueError('Please provide a valid number of layers')
    block_sizes = _block_size_options[self.num_layers]

    conv = functools.partial(nn.Conv, padding=[(3, 3), (3, 3)])
    x = conv(self.num_filters,
             kernel_size=(7, 7),
             strides=(2, 2),
             use_bias=False,
             dtype=self.dtype,
             name='conv0')(x)
    x = normalization.VirtualBatchNorm(
        momentum=self.batch_norm_momentum,
        epsilon=self.batch_norm_epsilon,
        name='init_bn',
        axis_name=self.axis_name,
        axis_index_groups=self.axis_index_groups,
        dtype=self.dtype,
        batch_size=self.batch_size,
        virtual_batch_size=self.virtual_batch_size,
        total_batch_size=self.total_batch_size,
        data_format=self.data_format)(x, use_running_average=not train)
    x = nn.relu(x)  # MLPerf-required
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    for i, block_size in enumerate(block_sizes):
        for j in range(block_size):
            strides = (2, 2) if i > 0 and j == 0 else (1, 1)
            x = ResidualBlock(self.num_filters * 2**i,
                              strides=strides,
                              axis_name=self.axis_name,
                              axis_index_groups=self.axis_index_groups,
                              dtype=self.dtype,
                              batch_norm_momentum=self.batch_norm_momentum,
                              batch_norm_epsilon=self.batch_norm_epsilon,
                              bn_output_scale=self.bn_output_scale,
                              batch_size=self.batch_size,
                              virtual_batch_size=self.virtual_batch_size,
                              total_batch_size=self.total_batch_size,
                              data_format=self.data_format)(x, train=train)

    x = jnp.mean(x, axis=(1, 2))
    x = nn.Dense(self.num_classes,
                 kernel_init=nn.initializers.normal(),
                 dtype=self.dtype)(x)
    return x
def setup(self):
    self.conv1 = nn.Conv(features=64, kernel_size=(5, 5), strides=(1, 1),
                         padding=((0, 0), (0, 0)), use_bias=True)
    self.conv2 = nn.Conv(features=64, kernel_size=(5, 5), strides=(1, 1),
                         padding=((0, 0), (0, 0)), use_bias=True)
    self.dense1 = nn.Dense(features=384)
    self.dense2 = nn.Dense(features=192)
    self.dense3 = nn.Dense(features=10)
    self.pool = lambda x: nn.max_pool(
        x, window_shape=(2, 2), strides=(2, 2), padding=((0, 0), (0, 0)))
    self.activation = nn.leaky_relu
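# Hypothetical sketch (not from the source): one plausible __call__ that wires
# up the submodules defined in the setup above, assuming NHWC image inputs.
def __call__(self, x):
    x = self.pool(self.activation(self.conv1(x)))
    x = self.pool(self.activation(self.conv2(x)))
    x = x.reshape((x.shape[0], -1))
    x = self.activation(self.dense1(x))
    x = self.activation(self.dense2(x))
    return self.dense3(x)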
def __call__(self, x):
    branch3x3 = self.conv_block(384, kernel_size=(3, 3), strides=(2, 2),
                                name='branch3x3')(x)

    branch3x3dbl = self.conv_block(64, kernel_size=(1, 1), name='branch3x3dbl_1')(x)
    branch3x3dbl = self.conv_block(96, kernel_size=(3, 3), padding=[(1, 1), (1, 1)],
                                   name='branch3x3dbl_2')(branch3x3dbl)
    branch3x3dbl = self.conv_block(96, kernel_size=(3, 3), strides=(2, 2),
                                   name='branch3x3dbl_3')(branch3x3dbl)

    branch_pool = nn.max_pool(x, (3, 3), strides=(2, 2))

    outputs = [branch3x3, branch3x3dbl, branch_pool]
    return jnp.concatenate(outputs, 3)
def __call__(
        self,
        x: jnp.ndarray,
        train: bool = True,
        debug: bool = False) -> Union[jnp.ndarray, Dict[str, jnp.ndarray]]:
    """Applies the Bit ResNet model to the inputs.

    Args:
      x: Inputs to the model.
      train: Unused.
      debug: Unused.

    Returns:
      Un-normalized logits if `num_outputs` is provided, a dictionary with
      representations otherwise.
    """
    del train
    del debug
    if self.max_output_stride not in [4, 8, 16, 32]:
        raise ValueError('Only supports output strides of [4, 8, 16, 32]')

    blocks, bottleneck = _BLOCK_SIZE_OPTIONS[self.num_layers]
    width = int(64 * self.width_factor)

    # Root block.
    x = StdConv(width, (7, 7), (2, 2), use_bias=False, name='conv_root')(x)
    x = nn.GroupNorm(num_groups=self.gn_num_groups, epsilon=1e-4, name='gn_root')(x)
    x = nn.relu(x)
    x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    representations = {'stem': x}

    # Stages.
    x = ResNetStage(blocks[0], width,
                    first_stride=(1, 1),
                    bottleneck=bottleneck,
                    gn_num_groups=self.gn_num_groups,
                    name='block1')(x)
    stride = 4
    for i, block_size in enumerate(blocks[1:], 1):
        max_stride_reached = self.max_output_stride <= stride
        x = ResNetStage(block_size,
                        width * 2**i,
                        first_stride=(2, 2) if not max_stride_reached else (1, 1),
                        first_dilation=(2, 2) if max_stride_reached else (1, 1),
                        bottleneck=bottleneck,
                        gn_num_groups=self.gn_num_groups,
                        name=f'block{i + 1}')(x)
        if not max_stride_reached:
            stride *= 2
        representations[f'stage_{i + 1}'] = x

    # Head.
    x = jnp.mean(x, axis=(1, 2))
    x = IdentityLayer(name='pre_logits')(x)
    representations['pre_logits'] = x
    x = nn.Dense(self.num_outputs, kernel_init=nn.initializers.zeros, name='head')(x)
    return x, representations
def __call__(self, inputs, *, train):
    x = inputs
    # (Possibly partial) ResNet root.
    if self.resnet is not None:
        width = int(64 * self.resnet.width_factor)

        # Root block.
        x = models_resnet.StdConv(features=width,
                                  kernel_size=(7, 7),
                                  strides=(2, 2),
                                  use_bias=False,
                                  name='conv_root')(x)
        x = nn.GroupNorm(name='gn_root')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding='SAME')

        # ResNet stages.
        if self.resnet.num_layers:
            x = models_resnet.ResNetStage(
                block_size=self.resnet.num_layers[0],
                nout=width,
                first_stride=(1, 1),
                name='block1')(x)
            for i, block_size in enumerate(self.resnet.num_layers[1:], 1):
                x = models_resnet.ResNetStage(block_size=block_size,
                                              nout=width * 2**i,
                                              first_stride=(2, 2),
                                              name=f'block{i + 1}')(x)

    n, h, w, c = x.shape

    # We can merge s2d+emb into a single conv; it's the same.
    x = nn.Conv(features=self.hidden_size,
                kernel_size=self.patches.size,
                strides=self.patches.size,
                padding='VALID',
                name='embedding')(x)

    # Here, x is a grid of embeddings.

    # Transformer.
    n, h, w, c = x.shape
    x = jnp.reshape(x, [n, h * w, c])

    # If we want to add a class token, add it here.
    if self.classifier == 'token':
        cls = self.param('cls', nn.initializers.zeros, (1, 1, c))
        cls = jnp.tile(cls, [n, 1, 1])
        x = jnp.concatenate([cls, x], axis=1)

    x = Encoder(name='Transformer', **self.transformer)(x, train=train)

    if self.classifier == 'token':
        x = x[:, 0]
    elif self.classifier == 'gap':
        x = jnp.mean(x, axis=list(range(1, x.ndim - 1)))  # (1,) or (1,2)
    else:
        raise ValueError(f'Invalid classifier={self.classifier}')

    if self.representation_size is not None:
        x = nn.Dense(features=self.representation_size, name='pre_logits')(x)
        x = nn.tanh(x)
    else:
        x = IdentityLayer(name='pre_logits')(x)

    if self.num_classes:
        x = nn.Dense(features=self.num_classes,
                     name='head',
                     kernel_init=nn.initializers.zeros)(x)
    return x
def __call__(self, inputs):
    """Applies ResNet model. Number of residual blocks inferred from hparams."""
    num_classes = self.num_classes
    hparams = self.hparams
    num_filters = self.num_filters
    dtype = self.dtype

    assert hparams.act_function in act_function_zoo.keys(), (
        'Activation function type is not supported.')

    x = aqt_flax_layers.ConvAqt(
        features=num_filters,
        kernel_size=(7, 7),
        strides=(2, 2),
        padding=[(3, 3), (3, 3)],
        use_bias=False,
        dtype=dtype,
        name='init_conv',
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.conv_init,
    )(inputs)
    x = nn.BatchNorm(use_running_average=not self.train,
                     momentum=0.9,
                     epsilon=1e-5,
                     dtype=dtype,
                     name='init_bn')(x)
    if hparams.act_function == 'relu':
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
    else:
        # TODO(yichi): try adding other activation functions here.
        # Use avg pool so that for binary nets, the distribution is symmetric.
        x = nn.avg_pool(x, (3, 3), strides=(2, 2), padding='SAME')

    filter_multiplier = hparams.filter_multiplier
    for i, block_hparams in enumerate(hparams.residual_blocks):
        proj = block_hparams.conv_proj
        # For projection layers (unless it is the first layer), strides = (2, 2).
        if i > 0 and proj is not None:
            filter_multiplier *= 2
            strides = (2, 2)
        else:
            strides = (1, 1)
        x = ResidualBlock(
            filters=int(num_filters * filter_multiplier),
            hparams=block_hparams,
            quant_context=self.quant_context,
            strides=strides,
            train=self.train,
            dtype=dtype)(x)

    if hparams.act_function == 'none':
        # The DenseAqt below is not binarized.
        # If removing the activation functions, there will be no act function
        # between the last residual block and the dense layer.
        # So add a ReLU in that case.
        # TODO(yichi): try BPReLU
        x = nn.relu(x)
    else:
        pass

    x = jnp.mean(x, axis=(1, 2))
    x = aqt_flax_layers.DenseAqt(
        features=num_classes,
        dtype=dtype,
        train=self.train,
        quant_context=self.quant_context,
        paxis_name='batch',
        hparams=hparams.dense_layer,
    )(x, padding_mask=None)
    x = jnp.asarray(x, dtype)
    output = nn.log_softmax(x)
    return output