def AvgPool(window_shape, strides=None, padding=Padding.VALID.value):
  """Layer construction function for a 2D average pooling layer.

  Based on `jax.experimental.stax.AvgPool`. Has a similar API apart from:

  Args:
    padding: in addition to `VALID` and `SAME` padding, supports `CIRCULAR`,
      not available in `jax.experimental.stax.AvgPool`.
  """
  strides = strides or (1,) * len(window_shape)
  padding = Padding(padding)

  if padding == Padding.CIRCULAR:
    init_fun, _ = stax.AvgPool(window_shape, strides, Padding.SAME.value)
    _, apply_fun_0 = stax.AvgPool(window_shape, strides, Padding.VALID.value)

    def apply_fun(params, inputs, **kwargs):
      inputs = _same_pad_for_filter_shape(inputs, window_shape, strides,
                                          (1, 2), 'wrap')
      res = apply_fun_0(params, inputs, **kwargs)
      return res
  else:
    init_fun, apply_fun = stax.AvgPool(window_shape, strides, padding.value)

  def ker_fun(kernels):
    """Kernel transformation."""
    var1, nngp, var2, ntk, is_gaussian, is_height_width = kernels

    if not is_height_width:
      window_shape_nngp = window_shape[::-1]
      strides_nngp = strides[::-1]
    else:
      window_shape_nngp = window_shape
      strides_nngp = strides

    nngp = _average_pool_nngp_6d(nngp, window_shape_nngp, strides_nngp,
                                 padding)
    ntk = _average_pool_nngp_6d(ntk, window_shape_nngp, strides_nngp, padding)

    if var2 is None:
      var1 = _diagonal_nngp_6d(nngp)
    else:
      # TODO(romann)
      warnings.warn(
          'Pooling for different inputs `x1` and `x2` is not '
          'implemented and will only work if there are no '
          'nonlinearities in the network anywhere after the pooling '
          'layer. `var1` and `var2` will have wrong values. '
          'This will be fixed soon.')

    return Kernel(var1, nngp, var2, ntk, is_gaussian, is_height_width)

  setattr(ker_fun, _USE_POOLING, True)

  return init_fun, apply_fun, ker_fun
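# Hedged usage sketch (not part of the original source): how the
# `(init_fun, apply_fun, ker_fun)` triple returned by `AvgPool` above can be
# used with `CIRCULAR` padding. The NHWC input layout, batch size, and spatial
# sizes are assumptions chosen for illustration only; `AvgPool` and `Padding`
# refer to the definitions above.
def _avg_pool_circular_example():
  from jax import random

  init_fun, apply_fun, ker_fun = AvgPool(
      (2, 2), strides=(2, 2), padding=Padding.CIRCULAR.value)

  key = random.PRNGKey(0)
  x = random.normal(key, (4, 8, 8, 3))   # NHWC inputs (assumed layout).
  _, params = init_fun(key, x.shape)     # stax-style init; params is ().
  out = apply_fun(params, x)             # Circularly padded average pooling.
  return out.shape                       # (4, 4, 4, 3) with these settings.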
def Resnet50(hidden_size=64, num_output_classes=1001):
  """ResNet.

  Args:
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: how many classes to distinguish.

  Returns:
    The ResNet model with the given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(), stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
      ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)),
      IdentityBlock(3, [hidden_size, hidden_size]),
      IdentityBlock(3, [hidden_size, hidden_size]),
      ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]),
      ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]),
      ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size],
                (2, 2)),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]),
      stax.AvgPool((7, 7)),
      stax.Flatten,
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
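# Hedged usage sketch (not part of the original source): initializing and
# applying the Resnet50 above, assuming `stax.serial` here follows the
# `jax.experimental.stax` convention of returning an `(init_fun, apply_fun)`
# pair. The ImageNet-sized NHWC input shape and batch size are assumptions
# for illustration only.
def _resnet50_example():
  from jax import random

  init_fun, apply_fun = Resnet50(hidden_size=64, num_output_classes=1001)
  key = random.PRNGKey(0)
  _, params = init_fun(key, (-1, 224, 224, 3))  # 224x224 keeps the final 7x7 pool valid.
  x = random.normal(key, (2, 224, 224, 3))
  log_probs = apply_fun(params, x)              # LogSoftmax output, shape (2, 1001).
  return log_probs.shape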
def __init__(self, num_classes=100, encoding=True):
  blocks = [
      stax.GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
      stax.BatchNorm(), stax.Relu,
      stax.MaxPool((3, 3), strides=(2, 2)),
      self.ConvBlock(3, [64, 64, 256], strides=(1, 1)),
      self.IdentityBlock(3, [64, 64]),
      self.IdentityBlock(3, [64, 64]),
      self.ConvBlock(3, [128, 128, 512]),
      self.IdentityBlock(3, [128, 128]),
      self.IdentityBlock(3, [128, 128]),
      self.IdentityBlock(3, [128, 128]),
      self.ConvBlock(3, [256, 256, 1024]),
      self.IdentityBlock(3, [256, 256]),
      self.IdentityBlock(3, [256, 256]),
      self.IdentityBlock(3, [256, 256]),
      self.IdentityBlock(3, [256, 256]),
      self.IdentityBlock(3, [256, 256]),
      self.ConvBlock(3, [512, 512, 2048]),
      self.IdentityBlock(3, [512, 512]),
      self.IdentityBlock(3, [512, 512]),
      stax.AvgPool((7, 7))
  ]
  if not encoding:
    blocks.append(stax.Flatten)
    blocks.append(stax.Dense(num_classes))
  self.model = stax.serial(*blocks)
def wide_resnet(n, k, num_classes):
  """Original WRN from paper and previous experiments."""
  return stax.serial(
      stax.Conv(16, (3, 3), padding='SAME'),
      wide_resnet_group(n, 16 * k, strides=(1, 1)),
      wide_resnet_group(n, 32 * k, strides=(2, 2)),
      wide_resnet_group(n, 64 * k, strides=(2, 2)),
      stax.BatchNorm(), stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten,
      stax.Dense(num_classes))
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10):
  """WideResnet from https://arxiv.org/pdf/1605.07146.pdf.

  Args:
    num_blocks: int, number of blocks in a group.
    hidden_size: the size of the first hidden layer (multiplied later).
    num_output_classes: int, number of classes to distinguish.

  Returns:
    The WideResnet model with given layer and output sizes.
  """
  return stax.serial(
      stax.Conv(hidden_size, (3, 3), padding='SAME'),
      WideResnetGroup(num_blocks, hidden_size),
      WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)),
      WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)),
      stax.BatchNorm(), stax.Relu,
      stax.AvgPool((8, 8)),
      stax.Flatten,
      stax.Dense(num_output_classes),
      stax.LogSoftmax)
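# Hedged usage sketch (not part of the original source): the WideResnet above
# on CIFAR-10-sized inputs, again assuming the `jax.experimental.stax`
# `(init_fun, apply_fun)` convention for `stax.serial`. The 32x32x3 NHWC input
# shape is an assumption that matches the final (8, 8) average pool.
def _wide_resnet_example():
  from jax import random

  init_fun, apply_fun = WideResnet(num_blocks=3, hidden_size=16,
                                   num_output_classes=10)
  key = random.PRNGKey(0)
  x = random.normal(key, (8, 32, 32, 3))  # Batch of 8 NHWC images (assumed).
  _, params = init_fun(key, x.shape)
  log_probs = apply_fun(params, x)        # Shape (8, 10) log-probabilities.
  return log_probs.shape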