def transition_block(x, reduction, name):
    """Downsample a feature map between two dense blocks.

    Pipeline: GroupNormalization -> ReLU -> 1x1 Conv2D whose channel count is
    the incoming channel count scaled by ``reduction`` -> 2x2 average pooling
    (stride 2) -> squeeze-and-excite block.

    # Arguments
        x: input tensor.
        reduction: float, compression rate applied to the channel count.
        name: string, prefix used for the names of the created layers.

    # Returns
        Output tensor of the transition block.
    """
    if backend.image_data_format() == 'channels_last':
        channel_axis = 3
    else:
        channel_axis = 1

    normed = GroupNormalization(axis=GN_AXIS, groups=4, scale=False,
                                name=name + '_gn')(x)
    activated = layers.Activation('relu', name=name + '_relu')(normed)

    # NOTE(review): normalization uses the module-level GN_AXIS while the
    # channel count below is read from `channel_axis` (derived from the
    # image data format); presumably these agree - confirm.
    out_channels = int(backend.int_shape(activated)[channel_axis] * reduction)
    projected = layers.Conv2D(out_channels, 1, use_bias=False,
                              name=name + '_conv')(activated)
    pooled = layers.AveragePooling2D(2, strides=2,
                                     name=name + '_pool')(projected)

    return squeeze_excite_block(pooled)
def conv_block(prev, num_filters, kernel=(3, 3), strides=(1, 1), act='relu', prefix=None):
    """Apply Conv2D -> GroupNormalization -> activation to ``prev``.

    Args:
        prev: input tensor.
        num_filters: number of Conv2D filters.
        kernel: Conv2D kernel size.
        strides: Conv2D strides.
        act: activation passed to the ``Activation`` layer.
        prefix: optional name prefix; when given, the layers are named
            ``prefix + '_conv'``, ``prefix + '_norm'`` and ``prefix + '_act'``.
            When omitted, ``name=None`` is passed to every layer.

    Returns:
        Output tensor of the activation layer.
    """
    def layer_name(suffix):
        return None if prefix is None else prefix + suffix

    out = Conv2D(num_filters, kernel, padding='same',
                 kernel_initializer='he_normal', strides=strides,
                 name=layer_name('_conv'))(prev)
    # GroupNormalization with its default group count on the GN_AXIS axis.
    out = GroupNormalization(name=layer_name('_norm'), axis=GN_AXIS)(out)
    return Activation(act, name=layer_name('_act'))(out)
def _gn_divisible_groups(channels, channels_per_group=4):
    """Return a GroupNormalization group count that evenly divides ``channels``.

    Starts from ``channels // channels_per_group`` (at least 1) and steps down
    until the value divides ``channels``. When ``channels`` is a multiple of
    ``channels_per_group`` the result is exactly ``channels // channels_per_group``.
    """
    groups = max(1, channels // channels_per_group)
    while channels % groups:
        groups -= 1
    return groups


def conv2d_gn(x, filters, kernel_size, strides=1, padding='same', activation='relu', use_bias=False, name=None):
    """Utility function to apply conv + GN (+ activation).

    # Arguments
        x: input tensor.
        filters: filters in `Conv2D`.
        kernel_size: kernel size as in `Conv2D`.
        strides: strides in `Conv2D`.
        padding: padding mode in `Conv2D`.
        activation: activation applied after normalization as a separate
            `Activation` layer; `None` skips it.
        use_bias: whether to use a bias in `Conv2D`. GroupNormalization is
            only applied when this is False.
        name: name of the ops; will become `name + '_ac'` for the activation
            and `name + '_gn'` for the group normalization layer.

    # Returns
        Output tensor after applying `Conv2D`, `GroupNormalization` (when
        `use_bias` is False) and the activation (when not None).
    """
    x = layers.Conv2D(filters, kernel_size, strides=strides, padding=padding,
                      use_bias=use_bias, name=name)(x)
    if not use_bias:
        gn_axis = 1 if backend.image_data_format() == 'channels_first' else 3
        gn_name = None if name is None else name + '_gn'
        # Aim for ~4 channels per group, but always use a group count that
        # divides `filters`: plain `filters // 4` is 0 for filters < 4 and
        # need not divide filters (e.g. 14 // 4 == 3), which
        # GroupNormalization rejects.
        x = GroupNormalization(axis=gn_axis,
                               groups=_gn_divisible_groups(filters),
                               scale=False,
                               name=gn_name)(x)
    if activation is not None:
        ac_name = None if name is None else name + '_ac'
        x = layers.Activation(activation, name=ac_name)(x)
    return x
def _transfer_matching_weights(source, target, freeze=True):
    """Copy weights from ``source`` layers into same-named ``target`` layers.

    A source layer is transferred only when it has weights, ``target`` has a
    layer with the same name, and every weight array has the same shape.
    Everything else is skipped, for example:
    - BatchNormalization layers that this model replaces with differently
      named GroupNormalization layers;
    - ``conv1/conv`` when the input has a channel count other than 3.

    Matching by name (not by index) keeps the copy correct even though the
    target inserts extra layers (e.g. squeeze-and-excite) between encoder stages.

    Returns the number of layers whose weights were copied.
    """
    target_layers = {layer.name: layer for layer in target.layers}
    copied = 0
    for src_layer in tqdm(source.layers):
        src_weights = src_layer.get_weights()
        if not src_weights:
            continue  # nothing to transfer
        dst_layer = target_layers.get(src_layer.name)
        if dst_layer is None:
            continue
        dst_weights = dst_layer.get_weights()
        if [w.shape for w in src_weights] != [w.shape for w in dst_weights]:
            continue
        dst_layer.set_weights(src_weights)
        if freeze:
            # Only layers that actually received pretrained weights are frozen.
            dst_layer.trainable = False
        copied += 1
    return copied


def get_densenet121_unet_sigmoid_gn(input_shape=(CONFIG.img_h, CONFIG.img_w, CONFIG.img_c), output_channels=1, weights='imagenet'):
    """Build a U-Net with a DenseNet121-style encoder using GroupNormalization.

    Args:
        input_shape: shape passed to ``Input``.
        output_channels: number of channels of the final 1x1 sigmoid conv.
        weights: when ``'imagenet'``, pretrained DenseNet121 weights are copied
            into encoder layers with matching names and shapes, and those
            layers are frozen. Any other value leaves the model freshly
            initialized.

    Returns:
        A Keras ``Model`` mapping the input image to the sigmoid output.
    """
    blocks = [6, 12, 24, 16]
    img_input = Input(input_shape)

    # Encoder (DenseNet121 layout, GroupNormalization instead of BatchNorm).
    x = ZeroPadding2D(padding=((3, 3), (3, 3)))(img_input)
    x = Conv2D(64, 7, strides=2, use_bias=False, name='conv1/conv')(x)
    x = GroupNormalization(axis=GN_AXIS, groups=16, scale=False, name='conv1/gn')(x)
    x = Activation('relu', name='conv1/relu')(x)
    conv1 = x
    x = ZeroPadding2D(padding=((1, 1), (1, 1)))(x)
    x = MaxPooling2D(3, strides=2, name='pool1')(x)
    x = dense_block(x, blocks[0], name='conv2')
    conv2 = x
    x = transition_block(x, 0.5, name='pool2')
    x = dense_block(x, blocks[1], name='conv3')
    conv3 = x
    x = transition_block(x, 0.5, name='pool3')
    x = dense_block(x, blocks[2], name='conv4')
    conv4 = x
    x = transition_block(x, 0.5, name='pool4')
    x = dense_block(x, blocks[3], name='conv5')
    x = GroupNormalization(axis=GN_AXIS, groups=32, scale=False, name='conv5/gn')(x)
    conv5 = x
    # squeeze and excite block
    conv5 = squeeze_excite_block(conv5)

    # Decoder: upsample, concatenate the matching encoder skip, refine.
    conv6 = conv_block(UpSampling2D()(conv5), 320)
    conv6 = concatenate([conv6, conv4], axis=-1)
    conv6 = conv_block(conv6, 320)
    conv7 = conv_block(UpSampling2D()(conv6), 256)
    conv7 = concatenate([conv7, conv3], axis=-1)
    conv7 = conv_block(conv7, 256)
    conv8 = conv_block(UpSampling2D()(conv7), 128)
    conv8 = concatenate([conv8, conv2], axis=-1)
    conv8 = conv_block(conv8, 128)
    conv9 = conv_block(UpSampling2D()(conv8), 96)
    conv9 = concatenate([conv9, conv1], axis=-1)
    conv9 = conv_block(conv9, 96)
    conv10 = conv_block(UpSampling2D()(conv9), 64)
    conv10 = conv_block(conv10, 64)
    res = Conv2D(output_channels, (1, 1), activation='sigmoid')(conv10)
    model = Model(img_input, res)

    if weights == 'imagenet':
        densenet = DenseNet121(input_shape=(input_shape[0], input_shape[1], 3),
                               weights=weights, include_top=False)
        print("Loading imagenet weights.")
        # Index-based copying is wrong here: e.g. index 3 is 'conv1/gn'
        # (1 weight) in this model but 'conv1/bn' (4 weights) in DenseNet121,
        # and squeeze-excite layers shift all later indices. Match by name.
        copied = _transfer_matching_weights(densenet, model, freeze=True)
        print("Copied imagenet weights into {} layers.".format(copied))
        del densenet
    return model