def _create_octconv_residual_block(inputs, ch, N, alpha):
    """Stack ``N`` residual units made of paired OctConv stages.

    Every unit applies ``OctConv2D -> BatchNorm -> ReLU`` twice to the
    (high, low) tensor pair and then adds a skip connection to each branch.
    In the first unit the skip path is a 1x1 convolution followed by
    BatchNorm and ReLU, with ``int(ch * (1 - alpha))`` filters on the high
    branch and ``int(ch * alpha)`` on the low branch (presumably the channel
    split ``OctConv2D`` produces -- confirm against ``oct_conv``).  In later
    units the unit's own input tensors are used as the skip.

    Args:
        inputs: Pair ``(high, low)`` of tensors.
        ch: Value passed as ``filters`` to every ``OctConv2D``.
        N: Number of residual units; with ``N == 0`` the inputs are returned
            unchanged.
        alpha: Value passed as ``alpha`` to every ``OctConv2D``; also used to
            size the 1x1 skip convolutions.

    Returns:
        The list ``[high, low]`` after the last unit.
    """

    def _bn_relu(tensor):
        # BatchNorm is built and applied before the ReLU layer is created,
        # keeping layer creation order stable.
        normalized = layers.BatchNormalization()(tensor)
        return layers.Activation("relu")(normalized)

    high, low = inputs
    for unit_index in range(N):
        if unit_index > 0:
            # Channel counts already agree: plain identity skip.
            shortcut_high, shortcut_low = high, low
        else:
            # First unit: 1x1 projection on each branch for the skip path.
            shortcut_high = _bn_relu(
                layers.Conv2D(int(ch * (1 - alpha)), 1)(high)
            )
            shortcut_low = _bn_relu(
                layers.Conv2D(int(ch * alpha), 1)(low)
            )

        # Two identical OctConv -> BN -> ReLU stages (high branch first).
        for _ in range(2):
            high, low = oct_conv.OctConv2D(filters=ch, alpha=alpha)([high, low])
            high = _bn_relu(high)
            low = _bn_relu(low)

        high = layers.Add()([high, shortcut_high])
        low = layers.Add()([low, shortcut_low])

    return [high, low]
def create_octconvmlp_wide_resnet(alpha, N=4, k=10):
    """Build the OctConv-MLP Wide ResNet model (defaults: N=4, k=10).

    The network takes a ``(32, 32, 3)`` input and ends in a 10-way softmax.
    It consists of a 16-filter stem, three residual stages with
    ``16 * k``, ``32 * k`` and ``64 * k`` filters, and a head that merges the
    two branches.  Both branches are average-pooled by 2 before the second
    and third stages.

    Args:
        alpha: Passed to ``OctConv2D`` and to the residual stages; also used
            to size the stem's 1x1 convolutions
            (``int(16 * (1 - alpha))`` high, ``int(16 * alpha)`` low).
        N: Forwarded to ``_create_octconvmlp_residual_block`` for each stage.
        k: Multiplier applied to the stage filter counts.

    Returns:
        A ``Model`` mapping the input image to the softmax output.
    """
    image = layers.Input((32, 32, 3))
    # The low branch is seeded with a 2x average-pooled copy of the image.
    image_low = layers.AveragePooling2D(2)(image)

    # Stem: OctConv, then two 1x1 convolutions per branch, then BN + ReLU.
    high, low = oct_conv.OctConv2D(filters=16, alpha=alpha)([image, image_low])
    for _ in range(2):
        high = layers.Conv2D(int(16 * (1 - alpha)), 1)(high)
        low = layers.Conv2D(int(16 * alpha), 1)(low)
    high = layers.BatchNormalization()(high)
    high = layers.Activation("relu")(high)
    low = layers.BatchNormalization()(low)
    low = layers.Activation("relu")(low)

    # Residual stages; every stage after the first is preceded by 2x pooling.
    for stage_index, base_width in enumerate((16, 32, 64)):
        if stage_index:
            high = layers.AveragePooling2D(2)(high)
            low = layers.AveragePooling2D(2)(low)
        high, low = _create_octconvmlp_residual_block(
            [high, low], base_width * k, N, alpha
        )

    # Head: pool the high branch by 2 and concatenate it with the low branch
    # (assumes the low branch is still at half the high branch's resolution).
    high = layers.AveragePooling2D(2)(high)
    merged = layers.Concatenate()([high, low])
    merged = layers.Conv2D(64 * k, 1)(merged)
    merged = layers.BatchNormalization()(merged)
    merged = layers.Activation("relu")(merged)

    # Classifier
    pooled = layers.GlobalAveragePooling2D()(merged)
    probabilities = layers.Dense(10, activation="softmax")(pooled)

    return Model(image, probabilities)
def _octconvmlp_vgg_unit(high, low, filters, alpha):
    """Apply one conv unit: OctConv2D, two 1x1 convs per branch, BN, ReLU.

    Layers are created in the order high-then-low at every step so that
    automatically generated layer names stay stable.

    Args:
        high: High-branch input tensor.
        low: Low-branch input tensor.
        filters: Value passed as ``filters`` to ``OctConv2D``; the 1x1
            convolutions use ``int(filters * (1 - alpha))`` filters on the
            high branch and ``int(filters * alpha)`` on the low branch.
        alpha: Value passed as ``alpha`` to ``OctConv2D``.

    Returns:
        Tuple ``(high, low)`` of output tensors.
    """
    high, low = oct_conv.OctConv2D(filters=filters, alpha=alpha)([high, low])
    for _ in range(2):
        high = layers.Conv2D(int(filters * (1 - alpha)), 1)(high)
        low = layers.Conv2D(int(filters * alpha), 1)(low)
    high = layers.BatchNormalization()(high)
    high = layers.Activation("relu")(high)
    low = layers.BatchNormalization()(low)
    low = layers.Activation("relu")(low)
    return high, low


def create_octcovmlp_vgg(alpha):
    """Build the OctConv-MLP VGG-style model.

    The network takes a ``(32, 32, 3)`` input and ends in a 10-way softmax.
    It has five blocks of conv units (2, 2, 3, 3, 3 units with 64, 128, 256,
    512, 512 filters).  Both branches are 2x2 max-pooled after blocks 2 and 4
    only.  The head average-pools the high branch by 2, concatenates it with
    the low branch, applies a 4096-filter 1x1 convolution with BN + ReLU,
    then global average pooling and the dense softmax layer.

    The function name (spelled ``octcov``) is kept as is for existing callers.

    Args:
        alpha: Passed to every ``OctConv2D`` and used to size the 1x1
            convolutions of each branch.

    Returns:
        A ``Model`` mapping the input image to the softmax output.
    """
    # (filters, number of conv units, max-pool both branches afterwards)
    block_specs = (
        (64, 2, False),
        (128, 2, True),
        (256, 3, False),
        (512, 3, True),
        (512, 3, False),
    )

    image = layers.Input((32, 32, 3))
    # The low branch is seeded with a 2x average-pooled copy of the image.
    low = layers.AveragePooling2D(2)(image)
    high = image

    for filters, num_units, pool_after in block_specs:
        for _ in range(num_units):
            high, low = _octconvmlp_vgg_unit(high, low, filters, alpha)
        if pool_after:
            high = layers.MaxPooling2D((2, 2), strides=(2, 2))(high)
            low = layers.MaxPooling2D((2, 2), strides=(2, 2))(low)

    # Head: pool the high branch by 2 and concatenate it with the low branch.
    high = layers.AveragePooling2D(2)(high)
    x = layers.Concatenate()([high, low])
    x = layers.Conv2D(4096, 1)(x)
    x = layers.BatchNormalization()(x)
    x = layers.Activation("relu")(x)

    # Classifier
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(10, activation="softmax")(x)

    return Model(image, x)