def MakeMain(input_shape):
    """Build the main branch for a block whose output width matches its input.

    The final 1x1 convolution uses ``input_shape[3]`` output channels, so the
    number of output channels depends on the number of input channels.
    ``filters1``, ``filters2`` and ``ks`` are read from the enclosing scope.

    Args:
        input_shape: shape of the incoming tensor; index 3 is read as the
            channel count.

    Returns:
        The ``stax.serial`` combination of the three Conv/BatchNorm stages.
    """
    out_channels = input_shape[3]
    stages = [
        stax.Conv(filters1, (1, 1)),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(filters2, (ks, ks), padding='SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(out_channels, (1, 1)),
        stax.BatchNorm(),
    ]
    return stax.serial(*stages)
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
    """Build one WideResnet convolutional block with a summed shortcut.

    Args:
        channels: number of output channels of every convolution in the block.
        strides: strides of the first 3x3 convolution (and of the shortcut
            convolution when one is used).
        channel_mismatch: when true the shortcut is a strided 3x3 convolution
            instead of the identity, so both branches produce the same shape.

    Returns:
        A stax layer that fans the input out to both branches and sums them.
    """
    residual_layers = [
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(channels, (3, 3), strides, padding='SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(channels, (3, 3), padding='SAME'),
    ]
    residual = stax.serial(*residual_layers)

    if channel_mismatch:
        skip = stax.Conv(channels, (3, 3), strides, padding='SAME')
    else:
        skip = stax.Identity

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(residual, skip),
        stax.FanInSum,
    )
def ConvBlock(self, kernel_size, filters, strides=(2, 2)):
    """Build a strided residual block whose shortcut is itself a convolution.

    Args:
        kernel_size: side length of the square kernel of the middle convolution.
        filters: three channel counts, one per convolution of the main branch;
            the third is also used by the shortcut convolution.
        strides: strides applied by the first main-branch convolution and by
            the shortcut convolution.

    Returns:
        A stax layer summing both branches and applying Relu to the result.
    """
    first_width, middle_width, last_width = filters
    square_kernel = (kernel_size, kernel_size)

    main_branch = stax.serial(
        stax.Conv(first_width, (1, 1), strides),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(middle_width, square_kernel, padding='SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(last_width, (1, 1)),
        stax.BatchNorm(),
    )
    shortcut_branch = stax.serial(
        stax.Conv(last_width, (1, 1), strides),
        stax.BatchNorm(),
    )
    return stax.serial(
        stax.FanOut(2),
        stax.parallel(main_branch, shortcut_branch),
        stax.FanInSum,
        stax.Relu,
    )
def create_double_conv(d: int, out_channels: int, mid_channels: int, batch_norm: bool, activation: Callable):
    """Stack two convolution / optional BatchNorm / activation stages.

    Args:
        d: key into the module-level ``CONV`` table; also the length of the
            kernel tuple ``(3,) * d``.
        out_channels: channel count of the second convolution.
        mid_channels: channel count of the first convolution.
        batch_norm: when true each convolution is followed by
            ``stax.BatchNorm`` over axes ``0..d``; otherwise by
            ``stax.Identity``.
        activation: stax layer placed after each normalisation slot.

    Returns:
        The ``stax.serial`` combination of the six layers.
    """
    kernel = (3, ) * d
    norm_axes = tuple(range(d + 1))

    def make_norm():
        # A fresh layer per stage, exactly as in the unrolled form.
        if batch_norm:
            return stax.BatchNorm(axis=norm_axes)
        return stax.Identity

    stages = []
    for channels in (mid_channels, out_channels):
        stages.append(CONV[d](channels, kernel, padding='same'))
        stages.append(make_norm())
        stages.append(activation)
    return stax.serial(*stages)
def ConvBlock(kernel_size, filters, strides):
    """Build a ResNet convolutional striding block.

    Args:
        kernel_size: side length of the square kernel of the middle convolution.
        filters: three channel counts for the main-branch convolutions; the
            last one is shared with the shortcut convolution.
        strides: strides of the first main-branch convolution and of the
            shortcut convolution.

    Returns:
        A stax layer that sums the main and shortcut branches, then applies
        Relu.
    """
    width_in, width_mid, width_out = filters

    main_layers = [
        stax.Conv(width_in, (1, 1), strides),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(width_mid, (kernel_size, kernel_size), padding='SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.Conv(width_out, (1, 1)),
        stax.BatchNorm(),
    ]
    shortcut_layers = [
        stax.Conv(width_out, (1, 1), strides),
        stax.BatchNorm(),
    ]

    branches = stax.parallel(
        stax.serial(*main_layers),
        stax.serial(*shortcut_layers),
    )
    return stax.serial(stax.FanOut(2), branches, stax.FanInSum, stax.Relu)
def wide_resnet_block(num_channels, strides=(1, 1), channel_mismatch=False):
    """Build a Wide ResNet block with a summed shortcut.

    Args:
        num_channels: output channels of every convolution in the block.
        strides: strides of the first 3x3 convolution and, when present, of
            the shortcut convolution.
        channel_mismatch: when true the shortcut is BatchNorm -> Relu -> 3x3
            convolution; otherwise it is the identity.

    Returns:
        A stax layer that fans out to both branches and sums their outputs.
    """
    norm_act = stax.serial(stax.BatchNorm(), stax.Relu)

    def conv3x3(conv_strides):
        return stax.Conv(num_channels, (3, 3), conv_strides, padding='SAME')

    residual = stax.serial(
        norm_act,
        conv3x3(strides),
        stax.BatchNorm(),
        stax.Relu,
        conv3x3((1, 1)),
    )

    # The same BatchNorm/Relu layer pair object is reused at the head of the
    # shortcut, as in the original construction.
    skip = stax.serial(norm_act, conv3x3(strides)) if channel_mismatch else stax.Identity

    return stax.serial(
        stax.FanOut(2),
        stax.parallel(residual, skip),
        stax.FanInSum,
    )
def Resnet50(hidden_size=64, num_output_classes=1001):
    """ResNet.

    Args:
        hidden_size: the size of the first hidden layer (multiplied later).
        num_output_classes: how many classes to distinguish.

    Returns:
        The ResNet model with the given layer and output sizes.
    """
    # One entry per stage:
    # (width multiplier, ConvBlock strides, number of IdentityBlocks).
    stage_specs = (
        (1, (1, 1), 2),
        (2, (2, 2), 3),
        (4, (2, 2), 5),
        (8, (2, 2), 2),
    )

    layers = [
        stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
    ]
    for multiplier, conv_strides, num_identity in stage_specs:
        width = multiplier * hidden_size
        layers.append(ConvBlock(3, [width, width, 4 * width], conv_strides))
        for _ in range(num_identity):
            layers.append(IdentityBlock(3, [width, width]))
    layers.extend([
        stax.AvgPool((7, 7)),
        stax.Flatten,
        stax.Dense(num_output_classes),
        stax.LogSoftmax,
    ])
    return stax.serial(*layers)
def _batch_norm_internal(batchnorm, axis=(0, 1, 2)): """Layer constructor for a stax.BatchNorm layer with dummy kernel computation. Do not use kernels for architectures that include this function.""" bn = stax.BatchNorm() init_fn, apply_fn = bn kernel_fn = lambda kernels: kernels return init_fn, apply_fn, kernel_fn
def __init__(self, num_classes=100, encoding=True):
    """Assemble the network and store it on ``self.model``.

    Args:
        num_classes: output size of the final Dense layer; only used when
            ``encoding`` is false.
        encoding: when true the model ends at the average pool; when false a
            Flatten and a ``Dense(num_classes)`` layer are appended.
    """
    # One entry per stage:
    # (base width, extra ConvBlock keyword arguments, number of IdentityBlocks).
    stage_specs = (
        (64, {'strides': (1, 1)}, 2),
        (128, {}, 3),
        (256, {}, 5),
        (512, {}, 2),
    )

    layers = [
        stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        stax.BatchNorm(),
        stax.Relu,
        stax.MaxPool((3, 3), strides=(2, 2)),
    ]
    for width, conv_kwargs, num_identity in stage_specs:
        layers.append(self.ConvBlock(3, [width, width, 4 * width], **conv_kwargs))
        for _ in range(num_identity):
            layers.append(self.IdentityBlock(3, [width, width]))
    layers.append(stax.AvgPool((7, 7)))

    if not encoding:
        layers.extend([stax.Flatten, stax.Dense(num_classes)])

    self.model = stax.serial(*layers)
def wide_resnet(n, k, num_classes):
    """Original WRN from paper and previous experiments.

    Args:
        n: first argument forwarded to every ``wide_resnet_group``.
        k: multiplier applied to the base group widths 16, 32 and 64.
        num_classes: output size of the final Dense layer.

    Returns:
        The ``stax.serial`` model: an initial 3x3 convolution, three groups,
        then BatchNorm, Relu, average pooling, Flatten and Dense.
    """
    # (base width, strides) for each of the three groups.
    group_specs = ((16, (1, 1)), (32, (2, 2)), (64, (2, 2)))

    layers = [stax.Conv(16, (3, 3), padding='SAME')]
    for base_width, group_strides in group_specs:
        layers.append(wide_resnet_group(n, base_width * k, strides=group_strides))
    layers += [
        stax.BatchNorm(),
        stax.Relu,
        stax.AvgPool((8, 8)),
        stax.Flatten,
        stax.Dense(num_classes),
    ]
    return stax.serial(*layers)
def create_model(nbin, nhidden, nlayer):
    """Build a stack of Dense/LeakyRelu/BatchNorm units ending in a Softmax.

    Args:
        nbin: output size of the final Dense layer.
        nhidden: output size of each hidden Dense layer.
        nlayer: number of hidden Dense/LeakyRelu/BatchNorm units.

    Returns:
        The ``stax.serial`` combination of all layers.
    """
    def hidden_unit():
        # New layer objects for every repetition.
        return [
            stax.Dense(nhidden),
            stax.LeakyRelu,
            stax.BatchNorm(axis=(0, 1)),
        ]

    stack = [layer for _ in range(nlayer) for layer in hidden_unit()]
    stack.append(stax.Dense(nbin))
    stack.append(stax.Softmax)
    return stax.serial(*stack)
def testBatchNormShapeNHWC(self):
    """BatchNorm over axes (0, 1, 2) keeps the shape and has size-7 parameters."""
    shape = (4, 5, 6, 7)
    expected_param_shape = (7, )

    init_fun, apply_fun = stax.BatchNorm(axis=(0, 1, 2))
    batch = random_inputs(onp.random.RandomState(0), shape)

    # init_fun is called with the input shape only.
    result_shape, params = init_fun(shape)
    result = apply_fun(params, batch)

    self.assertEqual(result_shape, shape)
    beta, gamma = params
    self.assertEqual(beta.shape, expected_param_shape)
    self.assertEqual(gamma.shape, expected_param_shape)
    self.assertEqual(result_shape, result.shape)
def testBatchNormNoScaleOrCenter(self):
    """Without centre/scale parameters the output has ~0 mean and ~1 std.

    The statistics are taken over the same axes (0, 1, 2) that BatchNorm
    normalises, with an absolute tolerance of 1e-4.
    """
    reduce_axes = (0, 1, 2)
    shape = (4, 5, 6, 7)
    rng_key = random.PRNGKey(0)

    init_fun, apply_fun = stax.BatchNorm(axis=reduce_axes, center=False, scale=False)
    batch = random_inputs(onp.random.RandomState(0), shape)

    _, params = init_fun(rng_key, shape)
    normalized = apply_fun(params, batch)

    out_means = onp.mean(normalized, axis=reduce_axes)
    out_stds = onp.std(normalized, axis=reduce_axes)
    assert onp.allclose(out_means, onp.zeros_like(out_means), atol=1e-4)
    assert onp.allclose(out_stds, onp.ones_like(out_stds), atol=1e-4)
def testBatchNormShapeNCHW(self):
    """BatchNorm over axes (0, 2, 3) keeps the shape and has size-5 parameters."""
    # Regression test for https://github.com/google/jax/issues/461
    shape = (4, 5, 6, 7)
    expected_param_shape = (5, )

    init_fun, apply_fun = stax.BatchNorm(axis=(0, 2, 3))
    batch = random_inputs(onp.random.RandomState(0), shape)

    # init_fun is called with the input shape only.
    result_shape, params = init_fun(shape)
    result = apply_fun(params, batch)

    self.assertEqual(result_shape, shape)
    beta, gamma = params
    self.assertEqual(beta.shape, expected_param_shape)
    self.assertEqual(gamma.shape, expected_param_shape)
    self.assertEqual(result_shape, result.shape)
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
    """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

    Args:
        num_blocks: int, number of blocks in a group.
        hidden_size: the size of the first hidden layer (multiplied later).
        num_output_classes: int, number of classes to distinguish.

    Returns:
        The WideResnet model with given layer and output sizes.
    """
    stem = [stax.Conv(hidden_size, (3, 3), padding='SAME')]
    groups = [
        # The first group is built without a strides argument.
        WideResnetGroup(num_blocks, hidden_size),
        WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
        WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
    ]
    head = [
        stax.BatchNorm(),
        stax.Relu,
        stax.AvgPool((8, 8)),
        stax.Flatten,
        stax.Dense(num_output_classes),
        stax.LogSoftmax,
    ]
    return stax.serial(*(stem + groups + head))
def dense_net(in_channels: int, out_channels: int, layers: tuple or list, batch_norm=False, activation='ReLU') -> StaxNet:
    """Build and initialise a fully connected ``StaxNet``.

    Args:
        in_channels: size of the last input dimension; the net is created
            with input shape ``(-1, in_channels)``.
        out_channels: output size of the final Dense layer.
        layers: iterable of hidden layer widths, one Dense layer per entry.
        batch_norm: when true a ``stax.BatchNorm(axis=(0, ))`` follows each
            hidden activation.
        activation: one of ``'ReLU'``, ``'Sigmoid'`` or ``'tanh'``.

    Returns:
        The ``StaxNet`` after ``initialize()`` has been called on it.

    Raises:
        KeyError: if ``activation`` is not one of the supported names.
    """
    activation_table = {
        'ReLU': stax.Relu,
        'Sigmoid': stax.Sigmoid,
        'tanh': stax.Tanh
    }
    activation_layer = activation_table[activation]

    hidden = []
    for width in layers:
        hidden += [stax.Dense(width), activation_layer]
        if batch_norm:
            hidden.append(stax.BatchNorm(axis=(0, )))

    net_init, net_apply = stax.serial(*hidden, stax.Dense(out_channels))
    net = StaxNet(net_init, net_apply, (-1, in_channels))
    net.initialize()
    return net