def MakeMain(input_shape): # the number of output channels depends on the number of input channels return stax.serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu, stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(), stax.Relu, stax.Conv(input_shape[3], (1, 1)), stax.BatchNorm())
def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False): """WideResnet convolutational block.""" main = stax.Serial(stax.BatchNorm(), stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'), stax.BatchNorm(), stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME')) shortcut = stax.Identity() if not channel_mismatch else stax.Conv( channels, (3, 3), strides, padding='SAME') return stax.Serial(stax.FanOut(), stax.Parallel(main, shortcut), stax.FanInSum())
def IdentityBlock(kernel_size, filters): """ResNet identical size block.""" ks = kernel_size filters1, filters2, filters3 = filters main = stax.Serial(stax.Conv(filters1, (1, 1)), stax.BatchNorm(), stax.Relu(), stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(), stax.Relu(), stax.Conv(filters3, (1, 1)), stax.BatchNorm()) return stax.Serial(stax.FanOut(), stax.Parallel(main, stax.Identity()), stax.FanInSum(), stax.Relu())
def ConvBlock(kernel_size, filters, strides): """ResNet convolutional striding block.""" ks = kernel_size filters1, filters2, filters3 = filters main = stax.Serial(stax.Conv(filters1, (1, 1), strides), stax.BatchNorm(), stax.Relu(), stax.Conv(filters2, (ks, ks), padding='SAME'), stax.BatchNorm(), stax.Relu(), stax.Conv(filters3, (1, 1)), stax.BatchNorm()) shortcut = stax.Serial(stax.Conv(filters3, (1, 1), strides), stax.BatchNorm()) return stax.Serial(stax.FanOut(), stax.Parallel(main, shortcut), stax.FanInSum(), stax.Relu())
def Resnet50(hidden_size=64, num_output_classes=1001): """ResNet. Args: hidden_size: the size of the first hidden layer (multiplied later). num_output_classes: how many classes to distinguish. Returns: The ResNet model with the given layer and output sizes. """ return stax.serial( stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'), stax.BatchNorm(), stax.Relu, stax.MaxPool((3, 3), strides=(2, 2)), ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)), IdentityBlock(3, [hidden_size, hidden_size]), IdentityBlock(3, [hidden_size, hidden_size]), ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)), IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]), IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]), IdentityBlock(3, [2 * hidden_size, 2 * hidden_size]), ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size], (2, 2)), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size]), ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size], (2, 2)), IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]), IdentityBlock(3, [8 * hidden_size, 8 * hidden_size]), stax.AvgPool((7, 7)), stax.Flatten(), stax.Dense(num_output_classes), stax.LogSoftmax)
def WideResnet(num_blocks=3, hidden_size=64, num_output_classes=10): """WideResnet from https://arxiv.org/pdf/1605.07146.pdf. Args: num_blocks: int, number of blocks in a group. hidden_size: the size of the first hidden layer (multiplied later). num_output_classes: int, number of classes to distinguish. Returns: The WideResnet model with given layer and output sizes. """ return stax.serial(stax.Conv(hidden_size, (3, 3), padding='SAME'), WideResnetGroup(num_blocks, hidden_size), WideResnetGroup(num_blocks, hidden_size * 2, (2, 2)), WideResnetGroup(num_blocks, hidden_size * 4, (2, 2)), stax.BatchNorm(), stax.Relu, stax.AvgPool((8, 8)), stax.Flatten(), stax.Dense(num_output_classes), stax.LogSoftmax)
def Resnet50(hidden_size=64, num_output_classes=1001, mode='train'): """ResNet. Args: hidden_size: the size of the first hidden layer (multiplied later). num_output_classes: how many classes to distinguish. mode: whether we are training or evaluating or doing inference. Returns: The ResNet model with the given layer and output sizes. """ del mode return stax.Serial( stax.Conv(hidden_size, (7, 7), (2, 2), 'SAME'), stax.BatchNorm(), stax.Relu(), stax.MaxPool(pool_size=(3, 3), strides=(2, 2)), ConvBlock(3, [hidden_size, hidden_size, 4 * hidden_size], (1, 1)), IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]), IdentityBlock(3, [hidden_size, hidden_size, 4 * hidden_size]), ConvBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size], (2, 2)), IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]), IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]), IdentityBlock(3, [2 * hidden_size, 2 * hidden_size, 8 * hidden_size]), ConvBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size], (2, 2)), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]), IdentityBlock(3, [4 * hidden_size, 4 * hidden_size, 16 * hidden_size]), ConvBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size], (2, 2)), IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]), IdentityBlock(3, [8 * hidden_size, 8 * hidden_size, 32 * hidden_size]), stax.AvgPool(pool_size=(7, 7)), stax.Flatten(), stax.Dense(num_output_classes), stax.LogSoftmax())