def identity_block(inputs):
    """Run ``inputs`` through a three-convolution branch and add it back.

    The branch is Conv -> BatchNorm -> relu, twice, followed by a final
    1x1 Conv -> BatchNorm without activation. Its output is summed with the
    untouched ``inputs`` and passed through ``relu``.

    ``filters1``, ``filters2`` and ``ks`` are free names resolved from a
    scope outside this block (not visible here).

    Args:
        inputs: Array-like with at least four axes; the size of axis 3 is
            used as the filter count of the last convolution (presumably
            the channel axis of an NHWC batch -- confirm against caller).

    Returns:
        ``relu`` applied to the sum of the branch output and ``inputs``.
    """
    branch_layers = (
        Conv(filters1, (1, 1)),
        BatchNorm(),
        relu,
        Conv(filters2, (ks, ks), padding='SAME'),
        BatchNorm(),
        relu,
        # Filter count mirrors inputs.shape[3], presumably so the branch
        # output can be added element-wise to the input.
        Conv(inputs.shape[3], (1, 1)),
        BatchNorm(),
    )
    branch = Sequential(*branch_layers)
    branch_out = branch(inputs)
    return relu(sum((branch_out, inputs)))
def conv_block(inputs):
    """Run ``inputs`` through a main branch and a convolutional shortcut.

    The main branch is Conv -> BatchNorm -> relu, twice, followed by a final
    1x1 Conv -> BatchNorm without activation. The shortcut is a single 1x1
    Conv -> BatchNorm. Both the first main convolution and the shortcut
    convolution receive ``strides``; both branches end with ``filters3``
    filters. The two outputs are summed and passed through ``relu``.

    ``filters1``, ``filters2``, ``filters3``, ``ks`` and ``strides`` are
    free names resolved from a scope outside this block (not visible here).

    Args:
        inputs: The value fed to both branches.

    Returns:
        ``relu`` applied to the sum of the main and shortcut outputs.
    """
    main_layers = (
        Conv(filters1, (1, 1), strides),
        BatchNorm(),
        relu,
        Conv(filters2, (ks, ks), padding='SAME'),
        BatchNorm(),
        relu,
        Conv(filters3, (1, 1)),
        BatchNorm(),
    )
    main_path = Sequential(*main_layers)
    shortcut_path = Sequential(
        Conv(filters3, (1, 1), strides),
        BatchNorm(),
    )

    # Main branch is applied before the shortcut, as in the original.
    main_out = main_path(inputs)
    shortcut_out = shortcut_path(inputs)
    return relu(sum((main_out, shortcut_out)))
def test_BatchNorm_shape_NCHW(center, scale):
    """Check BatchNorm output and parameter shapes when reducing over axes (0, 2, 3).

    A ``(4, 5, 6, 7)`` input is normalised with ``axis=(0, 2, 3)``. The
    output must keep the input shape, and ``beta`` / ``gamma`` (checked
    only when ``center`` / ``scale`` are truthy) must have shape ``(5,)``,
    i.e. the size of the one axis not listed in ``axis``.

    Args:
        center: Forwarded to ``BatchNorm``; when truthy, ``params.beta`` is
            checked. Presumably supplied by test parametrisation outside
            this block.
        scale: Forwarded to ``BatchNorm``; when truthy, ``params.gamma`` is
            checked. Presumably supplied by test parametrisation outside
            this block.
    """
    shape = (4, 5, 6, 7)
    per_channel_shape = (5, )

    layer = BatchNorm(axis=(0, 2, 3), center=center, scale=scale)
    x = random_inputs(shape)
    layer_params = layer.init_parameters(PRNGKey(0), x)
    y = layer.apply(layer_params, x)

    assert y.shape == shape
    if center:
        assert layer_params.beta.shape == per_channel_shape
    if scale:
        assert layer_params.gamma.shape == per_channel_shape
def ResNet50(num_classes):
    """Assemble the ResNet50 layer stack as a single ``Sequential``.

    Layout, in order:
      * stem: 7x7 ``GeneralConv`` with 64 filters and (2, 2) strides,
        ``BatchNorm``, ``relu``, 3x3 ``MaxPool`` with (2, 2) strides;
      * four stages, each one ``ConvBlock`` followed by 2, 3, 5 and 2
        ``IdentityBlock``s respectively;
      * head: 7x7 ``AvgPool``, ``flatten``, ``Dense(num_classes)``,
        ``logsoftmax``.

    Args:
        num_classes: Output width of the final ``Dense`` layer.

    Returns:
        The value of ``Sequential(...)`` applied to the layers above.
    """
    # (ConvBlock filters, number of IdentityBlocks, extra ConvBlock kwargs).
    # Only the first stage passes strides explicitly; the remaining stages
    # rely on ConvBlock's own default.
    stage_specs = (
        ([64, 64, 256], 2, {'strides': (1, 1)}),
        ([128, 128, 512], 3, {}),
        ([256, 256, 1024], 5, {}),
        ([512, 512, 2048], 2, {}),
    )

    layers = [
        # Dimension-number spec is passed through verbatim.
        GeneralConv(('HWCN', 'OIHW', 'NHWC'), 64, (7, 7), (2, 2), 'SAME'),
        BatchNorm(),
        relu,
        MaxPool((3, 3), strides=(2, 2)),
    ]

    for block_filters, identity_count, conv_kwargs in stage_specs:
        layers.append(ConvBlock(3, block_filters, **conv_kwargs))
        # Each IdentityBlock gets its own fresh two-element filter list.
        layers.extend(
            IdentityBlock(3, block_filters[:2]) for _ in range(identity_count)
        )

    layers.extend([
        AvgPool((7, 7)),
        flatten,
        Dense(num_classes),
        logsoftmax,
    ])
    return Sequential(*layers)