def WideResnet(block_size, k, num_classes, channels=1024, nonlinearity='relu',
               parameterization='standard', order=None):
    """Assemble a wide residual network as a ``jax_stax.serial`` layer stack.

    The stack consists of, in order: a 3x3 ``'SAME'``-padded ``MyConv`` with
    ``channels`` filters, three ``WideResnetGroup`` stages of widths
    ``int(16 * k)``, ``int(32 * k)`` and ``int(64 * k)``, an 8x8 ``AvgPool``,
    ``Flatten`` and a final ``MyDense`` with ``num_classes`` outputs.

    Args:
        block_size: First argument forwarded to every ``WideResnetGroup``.
        k: Width multiplier applied to the base widths 16, 32 and 64.
        num_classes: Output size of the final ``MyDense`` layer.
        channels: Number of filters of the initial ``MyConv``.
        nonlinearity: ``'relu'`` or ``'swish'``; selects the ``nonlin`` layer
            handed to each ``WideResnetGroup``.
        parameterization: Forwarded unchanged to ``MyConv``,
            ``WideResnetGroup`` and ``MyDense``.
        order: Forwarded unchanged to the same layers.

    Returns:
        Whatever ``jax_stax.serial`` returns for the assembled layers.

    Raises:
        ValueError: If ``nonlinearity`` is neither ``'relu'`` nor ``'swish'``.
    """
    if nonlinearity == 'relu':
        nonlin = Relu
    elif nonlinearity == 'swish':
        nonlin = Swish
    else:
        # Fail early with a clear message instead of leaving `nonlin` unbound.
        raise ValueError(
            "Unsupported nonlinearity %r; expected 'relu' or 'swish'."
            % (nonlinearity,))

    # Keyword arguments shared by every parameterized layer in the stack.
    shared = dict(parameterization=parameterization, order=order)

    return jax_stax.serial(
        MyConv(channels, (3, 3), padding='SAME', **shared),
        WideResnetGroup(block_size, int(16 * k), nonlin=nonlin, **shared),
        # The (2, 2) positional argument is presumably the group's stride --
        # TODO confirm against WideResnetGroup's signature.
        WideResnetGroup(block_size, int(32 * k), (2, 2), nonlin=nonlin, **shared),
        WideResnetGroup(block_size, int(64 * k), (2, 2), nonlin=nonlin, **shared),
        AvgPool((8, 8)),
        Flatten,
        MyDense(num_classes, **shared))
def ResNet50(num_classes):
    """Build the ResNet-50 layer stack and wrap it with ``stax.serial``.

    Layout: a stem (``GeneralConv`` with 64 filters of 7x7 and ``(2, 2)``
    strides, ``BatchNorm``, ``Relu``, 3x3 ``MaxPool`` with ``(2, 2)`` strides),
    four residual stages, then a 7x7 ``AvgPool``, ``Flatten``,
    ``Dense(num_classes)`` and ``LogSoftmax``.

    Each stage is one ``ConvBlock(3, [w, w, 4 * w])`` followed by a number of
    ``IdentityBlock(3, [w, w])`` layers; only the first stage passes
    ``strides=(1, 1)`` to its ``ConvBlock``.

    Args:
        num_classes: Output size of the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` applied to all layers in order.
    """
    # (stage width w, extra ConvBlock kwargs, number of IdentityBlocks)
    stage_table = (
        (64, {"strides": (1, 1)}, 2),
        (128, {}, 3),
        (256, {}, 5),
        (512, {}, 2),
    )

    stem = [
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    residual = []
    for width, conv_kwargs, identity_count in stage_table:
        residual.append(ConvBlock(3, [width, width, 4 * width], **conv_kwargs))
        residual.extend(
            IdentityBlock(3, [width, width]) for _ in range(identity_count))

    head = [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]

    return stax.serial(*stem, *residual, *head)
def LeNet5(num_classes):
    """Build a small convolutional classifier wrapped with ``stax.serial``.

    The network is two convolution stages (each followed by ``BatchNorm``,
    ``Relu`` and a 3x3 ``AvgPool``) and a classifier of three ``Dense`` layers
    of sizes ``num_classes * 10``, ``num_classes * 5`` and ``num_classes``,
    ending in ``LogSoftmax``.

    Args:
        num_classes: Output size of the last ``Dense`` layer; also scales the
            two preceding ``Dense`` layers.

    Returns:
        The result of ``stax.serial`` applied to all layers in order.
    """
    first_stage = (
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),
    )
    second_stage = (
        Conv(16, (5, 5), strides=(1, 1), padding="SAME"),
        BatchNorm(),
        Relu,
        AvgPool((3, 3)),
    )
    # No activation is placed between the Dense layers.
    classifier = (
        Flatten,
        Dense(num_classes * 10),
        Dense(num_classes * 5),
        Dense(num_classes),
        LogSoftmax,
    )
    return stax.serial(*first_stage, *second_stage, *classifier)
def ResNet(num_classes):
    """Build a 50-layer-style residual network wrapped with ``stax.serial``.

    Layout: a stem (``GeneralConv`` with 64 filters of 7x7 and ``(2, 2)``
    strides, ``BatchNorm``, ``Relu``, 3x3 ``MaxPool`` with ``(2, 2)`` strides),
    four residual stages, then a 7x7 ``AvgPool``, ``Flatten``,
    ``Dense(num_classes)`` and ``LogSoftmax``.

    Every stage is one ``convBlock(3, [w, w, 4 * w])`` followed by a number of
    ``identityBlock(3, [w, w])`` layers.

    Args:
        num_classes: Output size of the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` applied to all layers in order.
    """
    # (stage width w, number of identityBlocks after the convBlock)
    stage_plan = ((64, 2), (128, 3), (256, 5), (512, 2))

    network = [
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    for width, repeats in stage_plan:
        network.append(convBlock(3, [width, width, 4 * width]))
        for _ in range(repeats):
            network.append(identityBlock(3, [width, width]))

    network += [AvgPool((7, 7)), Flatten, Dense(num_classes), LogSoftmax]

    return stax.serial(*network)
def ResNet(num_classes):
    """Build a miniature residual network wrapped with ``stax.serial``.

    The network has the usual stem (``GeneralConv`` with 64 filters of 7x7 and
    ``(2, 2)`` strides, ``BatchNorm``, ``Relu``, 3x3 ``MaxPool`` with ``(2, 2)``
    strides) followed by a single ``ConvBlock`` / ``IdentityBlock`` pair with
    4 filters each, a 3x3 ``AvgPool``, ``Flatten``, ``Dense(num_classes)`` and
    ``LogSoftmax``.

    Args:
        num_classes: Output size of the final ``Dense`` layer.

    Returns:
        The result of ``stax.serial`` applied to all layers in order.
    """
    layer_sequence = [
        # Stem
        GeneralConv(("HWCN", "OIHW", "NHWC"), 64, (7, 7), (2, 2), "SAME"),
        BatchNorm(),
        Relu,
        MaxPool((3, 3), strides=(2, 2)),
        # Single, very narrow residual stage
        ConvBlock(3, [4, 4, 4], strides=(1, 1)),
        IdentityBlock(3, [4, 4]),
        # Classifier head
        AvgPool((3, 3)),
        Flatten,
        Dense(num_classes),
        LogSoftmax,
    ]
    return stax.serial(*layer_sequence)
def ResNet(block, expansion, layers, normalization_method=None,
           width_per_group=64, actfn=stax.Relu):
    """Return the layer list of a configurable ResNet feature extractor.

    The result is a plain Python list (it is NOT wrapped with ``stax.serial``)
    containing a 3x3 stem convolution, the normalization layer, the
    activation, four residual stages of widths 64, 128, 256 and 512, a 4x4
    ``AvgPool`` and ``Flatten``. No classifier layer is appended.

    Args:
        block: Callable that builds one residual block. It is called as
            ``block(planes, stride, downsample, base_width, norm_layer,
            actfn=...)`` for the first block of a stage and as
            ``block(planes, base_width=..., norm_layer=..., actfn=...)`` for
            the remaining ones.
        expansion: Multiplier applied to ``planes`` to get the channel count
            of the 1x1 shortcut convolution.
        layers: Indexable with at least four entries; ``layers[i]`` is the
            number of blocks in stage ``i``.
        normalization_method: ``"group_norm"`` selects ``GroupNorm(32)``,
            ``"batch_norm"`` selects ``BatchNorm()``; any other value
            (including ``None``) selects ``Identity``.
        width_per_group: Passed to every block as ``base_width``.
        actfn: Activation layer used after the stem and forwarded to every
            block.

    Returns:
        list: The layers in order, ready to be unpacked into ``stax.serial``.
    """
    if normalization_method == "group_norm":
        norm_layer = GroupNorm(32)
    elif normalization_method == "batch_norm":
        norm_layer = BatchNorm()
    else:
        norm_layer = Identity

    base_width = width_per_group

    def _make_layer(block, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` blocks as a serial layer."""
        # NOTE(review): a shortcut projection is only created when the stage
        # changes resolution (stride != 1). Whether `block` copes with a
        # channel change at stride 1 (e.g. expansion > 1 in the first stage)
        # cannot be confirmed from here.
        downsample = None
        if stride != 1:
            downsample = stax.serial(
                Conv(planes * expansion, (1, 1), strides=(stride, stride),
                     bias=False),
                norm_layer,
            )

        # The first block carries the stride/shortcut. `actfn` is forwarded
        # here as well so a non-default activation applies to every block of
        # the stage, not only to the trailing ones.
        stage = [
            block(planes, stride, downsample, base_width, norm_layer,
                  actfn=actfn)
        ]
        stage.extend(
            block(planes, base_width=base_width, norm_layer=norm_layer,
                  actfn=actfn)
            for _ in range(1, blocks))
        return stax.serial(*stage)

    return [
        Conv(64, (3, 3), strides=(1, 1), padding="SAME", bias=False),
        norm_layer,
        actfn,
        # Stem max-pool is disabled:
        # MaxPool((3, 3), strides=(2, 2), padding="SAME"),
        _make_layer(block, 64, layers[0]),
        _make_layer(block, 128, layers[1], stride=2),
        _make_layer(block, 256, layers[2], stride=2),
        _make_layer(block, 512, layers[3], stride=2),
        AvgPool((4, 4)),
        Flatten,
    ]