def _create_octconv_residual_block(inputs, ch, N, alpha):
    """Build a stack of ``N`` OctConv residual units on a (high, low) pair.

    The first unit projects both frequency branches with 1x1 convolutions so
    the skip connection matches the residual branch's channel counts; later
    units use identity skips.

    NOTE(review): a later function in this file re-defines this same name, so
    this definition appears to be shadowed dead code — confirm which version
    is intended to be live.

    Args:
        inputs: ``[high, low]`` pair of feature tensors.
        ch: total number of filters per OctConv2D layer.
        N: number of residual units to stack.
        alpha: fraction of channels routed to the low-frequency branch.

    Returns:
        ``[high, low]`` output tensors.
    """
    def oct_bn_relu(h, l):
        # One OctConv2D followed by per-branch BatchNorm + ReLU.
        h, l = OctConv2D(filters=ch, alpha=alpha)([h, l])
        h = layers.Activation("relu")(layers.BatchNormalization()(h))
        l = layers.Activation("relu")(layers.BatchNormalization()(l))
        return h, l

    high, low = inputs
    for unit in range(N):
        if unit == 0:
            # Channel-adjusting 1x1 projections for the skip path.
            # NOTE(review): assumes OctConv2D splits its filters as
            # int(ch * (1 - alpha)) high / int(ch * alpha) low — confirm
            # against the OctConv2D implementation, since a mismatch would
            # break the Add below.
            skip_high = layers.Conv2D(int(ch * (1 - alpha)), 1)(high)
            skip_high = layers.Activation("relu")(
                layers.BatchNormalization()(skip_high))
            skip_low = layers.Conv2D(int(ch * alpha), 1)(low)
            skip_low = layers.Activation("relu")(
                layers.BatchNormalization()(skip_low))
        else:
            # Identity skip once the channel counts already match.
            skip_high, skip_low = high, low
        high, low = oct_bn_relu(high, low)
        high, low = oct_bn_relu(high, low)
        high = layers.Add()([high, skip_high])
        low = layers.Add()([low, skip_low])
    return [high, low]
def _create_octconv_residual_block(inputs, ch, N, alpha):
    """Build ``N`` OctConv units: one channel-adjusting unit + N-1 residual units.

    The leading OctConv2D brings both frequency branches to the target channel
    count; each of the remaining ``N - 1`` units applies two OctConv -> BN ->
    ReLU stages and adds an identity skip connection.

    Args:
        inputs: ``[high, low]`` pair of feature tensors.
        ch: total number of filters per OctConv2D layer.
        N: total number of units (1 plain + ``N - 1`` residual).
        alpha: fraction of channels routed to the low-frequency branch.

    Returns:
        ``[high, low]`` output tensors.
    """
    def oct_bn_relu(pair):
        # One OctConv2D followed by per-branch BatchNorm + ReLU.
        h, l = OctConv2D(filters=ch, alpha=alpha)(pair)
        h = layers.Activation("relu")(layers.BatchNormalization()(h))
        l = layers.Activation("relu")(layers.BatchNormalization()(l))
        return h, l

    # Channel adjustment on the raw inputs (no skip around this stage).
    high, low = oct_bn_relu(inputs)

    # Residual units with identity skips.
    for _ in range(N - 1):
        skip_high, skip_low = high, low
        high, low = oct_bn_relu([high, low])
        high, low = oct_bn_relu([high, low])
        high = layers.Add()([high, skip_high])
        low = layers.Add()([low, skip_low])
    return [high, low]
def create_octconv_wide_resnet_5(alpha, N=4, k=10, num_classes=5):
    """Create an OctConv Wide ResNet (N units per stage, widening factor k).

    Fixes over the previous version: the local ``input`` no longer shadows the
    ``input`` builtin, and the output width is parameterized instead of being
    hard-coded to 5 (the default preserves the original behavior).

    Args:
        alpha: fraction of channels routed to the low-frequency branch.
        N: number of units per residual stage.
        k: widening factor applied to the 16/32/64 base channel counts.
        num_classes: size of the softmax classification head (default 5,
            matching the original hard-coded value).

    Returns:
        A Keras ``Model`` mapping (32, 32, 3) images to ``num_classes``
        softmax probabilities.
    """
    # Input
    inputs = layers.Input((32, 32, 3))
    # Downsampled copy feeds the low-frequency branch.
    low = layers.AveragePooling2D(2)(inputs)

    # Stem: 16-channel OctConv with BN + ReLU on both branches.
    high, low = OctConv2D(filters=16, alpha=alpha)([inputs, low])
    high = layers.BatchNormalization()(high)
    high = layers.Activation("relu")(high)
    low = layers.BatchNormalization()(low)
    low = layers.Activation("relu")(low)

    # 1st block
    high, low = _create_octconv_residual_block([high, low], 16 * k, N, alpha)

    # 2nd block (downsample both branches first)
    high = layers.AveragePooling2D(2)(high)
    low = layers.AveragePooling2D(2)(low)
    high, low = _create_octconv_residual_block([high, low], 32 * k, N, alpha)

    # 3rd block: N - 1 regular units, then the dedicated last block that
    # merges the low branch back into a single tensor.
    high = layers.AveragePooling2D(2)(high)
    low = layers.AveragePooling2D(2)(low)
    high, low = _create_octconv_residual_block([high, low], 64 * k, N - 1, alpha)
    x = _create_octconv_last_residual_block([high, low], 64 * k, alpha)

    # Classification head
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dense(num_classes, activation="softmax")(x)

    model = Model(inputs, x)
    return model
def _create_octconv_last_residual_block(inputs, ch, alpha):
    """Final OctConv block: collapse the (high, low) pair into one tensor.

    Applies one OctConv2D stage, then closes both branches with plain 3x3
    convolutions (equivalent to alpha_out = 0, i.e. a vanilla Conv2D), and
    merges them at the high resolution.

    Args:
        inputs: ``[high, low]`` pair of feature tensors.
        ch: number of filters for the OctConv and the closing convolutions.
        alpha: fraction of channels routed to the low-frequency branch.

    Returns:
        A single high-resolution tensor with ``ch`` channels.
    """
    high, low = inputs

    # OctConv -> per-branch BatchNorm + ReLU.
    high, low = OctConv2D(filters=ch, alpha=alpha)([high, low])
    high = layers.Activation("relu")(layers.BatchNormalization()(high))
    low = layers.Activation("relu")(layers.BatchNormalization()(low))

    # high -> high: full-channel 3x3 conv on the high-resolution branch.
    high_to_high = layers.Conv2D(ch, 3, padding="same")(high)

    # low -> high: conv, then 2x nearest-neighbour upsampling (implemented
    # with K.repeat_elements along the two spatial axes) so the low branch
    # matches the high branch's spatial size.
    low_to_high = layers.Conv2D(ch, 3, padding="same")(low)
    upsample_2x = layers.Lambda(
        lambda t: K.repeat_elements(K.repeat_elements(t, 2, axis=1), 2, axis=2))
    low_to_high = upsample_2x(low_to_high)

    merged = layers.Add()([high_to_high, low_to_high])
    merged = layers.BatchNormalization()(merged)
    return layers.Activation("relu")(merged)